diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 7d1834723..d3c77d83e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -39,18 +39,18 @@ HybridBayesNet HybridBayesNet::prune( // Functional which loops over all assignments and create a set of // GaussianConditionals - auto pruner = - [&](const Assignment &choices, - const GaussianFactor::shared_ptr &gf) -> GaussianFactor::shared_ptr { + auto pruner = [&](const Assignment &choices, + const GaussianConditional::shared_ptr &conditional) + -> GaussianConditional::shared_ptr { // typecast so we can use this to get probability value DiscreteValues values(choices); if ((*discreteFactor)(values) == 0.0) { // empty aka null pointer - boost::shared_ptr null; + boost::shared_ptr null; return null; } else { - return gf; + return conditional; } }; @@ -74,28 +74,28 @@ HybridBayesNet HybridBayesNet::prune( continue; } - // The GaussianMixture stores all the conditionals and uneliminated - // factors in the factors tree. - auto gaussianMixtureTree = gaussianMixture->factors(); - // Run the pruning to get a new, pruned tree - auto prunedFactors = gaussianMixtureTree.apply(pruner); + GaussianMixture::Conditionals prunedTree = + gaussianMixture->conditionals().apply(pruner); DiscreteKeys discreteKeys = gaussianMixture->discreteKeys(); // reverse keys to get a natural ordering std::reverse(discreteKeys.begin(), discreteKeys.end()); - // Convert to GaussianConditionals - auto prunedTree = GaussianMixture::Conditionals( - prunedFactors, [](const GaussianFactor::shared_ptr &factor) { - return boost::dynamic_pointer_cast(factor); - }); + // Convert from boost::iterator_range to std::vector. + std::vector frontals, parents; + for (Key key : gaussianMixture->frontals()) { + frontals.push_back(key); + } + for (Key key : gaussianMixture->parents()) { + parents.push_back(key); + } // Create the new gaussian mixture and add it to the bayes net. auto prunedGaussianMixture = boost::make_shared( - gaussianMixture->frontals(), gaussianMixture->parents(), discreteKeys, - prunedTree); + frontals, parents, discreteKeys, prunedTree); + // Type-erase and add to the pruned Bayes Net fragment. prunedBayesNetFragment.push_back( boost::make_shared(prunedGaussianMixture)); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 43eead280..834c4c3ef 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -35,7 +36,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using sharedConditional = boost::shared_ptr; /** Construct empty bayes net */ - HybridBayesNet() = default; + HybridBayesNet() : Base() {} + + HybridBayesNet prune( + const DecisionTreeFactor::shared_ptr &discreteFactor) const; }; } // namespace gtsam