add pruning to HybridBayesNet

release/4.3a0
Varun Agrawal 2022-06-03 13:25:19 -04:00
parent ad77a45a0d
commit c2e5061b71
2 changed files with 22 additions and 18 deletions

View File

@ -39,18 +39,18 @@ HybridBayesNet HybridBayesNet::prune(
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner =
[&](const Assignment<Key> &choices,
const GaussianFactor::shared_ptr &gf) -> GaussianFactor::shared_ptr {
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
if ((*discreteFactor)(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianFactor> null;
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return gf;
return conditional;
}
};
@ -74,28 +74,28 @@ HybridBayesNet HybridBayesNet::prune(
continue;
}
// The GaussianMixture stores all the conditionals and uneliminated
// factors in the factors tree.
auto gaussianMixtureTree = gaussianMixture->factors();
// Run the pruning to get a new, pruned tree
auto prunedFactors = gaussianMixtureTree.apply(pruner);
GaussianMixture::Conditionals prunedTree =
gaussianMixture->conditionals().apply(pruner);
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
// reverse keys to get a natural ordering
std::reverse(discreteKeys.begin(), discreteKeys.end());
// Convert to GaussianConditionals
auto prunedTree = GaussianMixture::Conditionals(
prunedFactors, [](const GaussianFactor::shared_ptr &factor) {
return boost::dynamic_pointer_cast<GaussianConditional>(factor);
});
// Convert from boost::iterator_range to std::vector.
std::vector<Key> frontals, parents;
for (Key key : gaussianMixture->frontals()) {
frontals.push_back(key);
}
for (Key key : gaussianMixture->parents()) {
parents.push_back(key);
}
// Create the new gaussian mixture and add it to the bayes net.
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
gaussianMixture->frontals(), gaussianMixture->parents(), discreteKeys,
prunedTree);
frontals, parents, discreteKeys, prunedTree);
// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(prunedGaussianMixture));

View File

@ -17,6 +17,7 @@
#pragma once
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/inference/BayesNet.h>
@ -35,7 +36,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
using sharedConditional = boost::shared_ptr<ConditionalType>;
/** Construct empty bayes net */
HybridBayesNet() = default;
HybridBayesNet() : Base() {}
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
};
} // namespace gtsam