add pruning to HybridBayesNet
parent
ad77a45a0d
commit
c2e5061b71
|
|
@ -39,18 +39,18 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
|
||||
// Functional which loops over all assignments and create a set of
|
||||
// GaussianConditionals
|
||||
auto pruner =
|
||||
[&](const Assignment<Key> &choices,
|
||||
const GaussianFactor::shared_ptr &gf) -> GaussianFactor::shared_ptr {
|
||||
auto pruner = [&](const Assignment<Key> &choices,
|
||||
const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianConditional::shared_ptr {
|
||||
// typecast so we can use this to get probability value
|
||||
DiscreteValues values(choices);
|
||||
|
||||
if ((*discreteFactor)(values) == 0.0) {
|
||||
// empty aka null pointer
|
||||
boost::shared_ptr<GaussianFactor> null;
|
||||
boost::shared_ptr<GaussianConditional> null;
|
||||
return null;
|
||||
} else {
|
||||
return gf;
|
||||
return conditional;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -74,28 +74,28 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
continue;
|
||||
}
|
||||
|
||||
// The GaussianMixture stores all the conditionals and uneliminated
|
||||
// factors in the factors tree.
|
||||
auto gaussianMixtureTree = gaussianMixture->factors();
|
||||
|
||||
// Run the pruning to get a new, pruned tree
|
||||
auto prunedFactors = gaussianMixtureTree.apply(pruner);
|
||||
GaussianMixture::Conditionals prunedTree =
|
||||
gaussianMixture->conditionals().apply(pruner);
|
||||
|
||||
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
|
||||
// reverse keys to get a natural ordering
|
||||
std::reverse(discreteKeys.begin(), discreteKeys.end());
|
||||
|
||||
// Convert to GaussianConditionals
|
||||
auto prunedTree = GaussianMixture::Conditionals(
|
||||
prunedFactors, [](const GaussianFactor::shared_ptr &factor) {
|
||||
return boost::dynamic_pointer_cast<GaussianConditional>(factor);
|
||||
});
|
||||
// Convert from boost::iterator_range to std::vector.
|
||||
std::vector<Key> frontals, parents;
|
||||
for (Key key : gaussianMixture->frontals()) {
|
||||
frontals.push_back(key);
|
||||
}
|
||||
for (Key key : gaussianMixture->parents()) {
|
||||
parents.push_back(key);
|
||||
}
|
||||
|
||||
// Create the new gaussian mixture and add it to the bayes net.
|
||||
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
|
||||
gaussianMixture->frontals(), gaussianMixture->parents(), discreteKeys,
|
||||
prunedTree);
|
||||
frontals, parents, discreteKeys, prunedTree);
|
||||
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
prunedBayesNetFragment.push_back(
|
||||
boost::make_shared<HybridConditional>(prunedGaussianMixture));
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
|
||||
|
|
@ -35,7 +36,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
||||
|
||||
/** Construct empty bayes net */
|
||||
HybridBayesNet() = default;
|
||||
HybridBayesNet() : Base() {}
|
||||
|
||||
HybridBayesNet prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
Loading…
Reference in New Issue