add pruning to HybridBayesNet

release/4.3a0
Varun Agrawal 2022-06-03 13:25:19 -04:00
parent ad77a45a0d
commit c2e5061b71
2 changed files with 22 additions and 18 deletions

View File

@ -39,18 +39,18 @@ HybridBayesNet HybridBayesNet::prune(
// Functional which loops over all assignments and create a set of // Functional which loops over all assignments and create a set of
// GaussianConditionals // GaussianConditionals
auto pruner = auto pruner = [&](const Assignment<Key> &choices,
[&](const Assignment<Key> &choices, const GaussianConditional::shared_ptr &conditional)
const GaussianFactor::shared_ptr &gf) -> GaussianFactor::shared_ptr { -> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
DiscreteValues values(choices); DiscreteValues values(choices);
if ((*discreteFactor)(values) == 0.0) { if ((*discreteFactor)(values) == 0.0) {
// empty aka null pointer // empty aka null pointer
boost::shared_ptr<GaussianFactor> null; boost::shared_ptr<GaussianConditional> null;
return null; return null;
} else { } else {
return gf; return conditional;
} }
}; };
@ -74,28 +74,28 @@ HybridBayesNet HybridBayesNet::prune(
continue; continue;
} }
// The GaussianMixture stores all the conditionals and uneliminated
// factors in the factors tree.
auto gaussianMixtureTree = gaussianMixture->factors();
// Run the pruning to get a new, pruned tree // Run the pruning to get a new, pruned tree
auto prunedFactors = gaussianMixtureTree.apply(pruner); GaussianMixture::Conditionals prunedTree =
gaussianMixture->conditionals().apply(pruner);
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys(); DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
// reverse keys to get a natural ordering // reverse keys to get a natural ordering
std::reverse(discreteKeys.begin(), discreteKeys.end()); std::reverse(discreteKeys.begin(), discreteKeys.end());
// Convert to GaussianConditionals // Convert from boost::iterator_range to std::vector.
auto prunedTree = GaussianMixture::Conditionals( std::vector<Key> frontals, parents;
prunedFactors, [](const GaussianFactor::shared_ptr &factor) { for (Key key : gaussianMixture->frontals()) {
return boost::dynamic_pointer_cast<GaussianConditional>(factor); frontals.push_back(key);
}); }
for (Key key : gaussianMixture->parents()) {
parents.push_back(key);
}
// Create the new gaussian mixture and add it to the bayes net. // Create the new gaussian mixture and add it to the bayes net.
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>( auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
gaussianMixture->frontals(), gaussianMixture->parents(), discreteKeys, frontals, parents, discreteKeys, prunedTree);
prunedTree);
// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back( prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(prunedGaussianMixture)); boost::make_shared<HybridConditional>(prunedGaussianMixture));

View File

@ -17,6 +17,7 @@
#pragma once #pragma once
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
@ -35,7 +36,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
using sharedConditional = boost::shared_ptr<ConditionalType>; using sharedConditional = boost::shared_ptr<ConditionalType>;
/** Construct empty bayes net */ /** Construct empty bayes net */
HybridBayesNet() = default; HybridBayesNet() : Base() {}
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
}; };
} // namespace gtsam } // namespace gtsam