add pruning to HybridBayesNet
parent
ad77a45a0d
commit
c2e5061b71
|
|
@ -39,18 +39,18 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
|
|
||||||
// Functional which loops over all assignments and create a set of
|
// Functional which loops over all assignments and create a set of
|
||||||
// GaussianConditionals
|
// GaussianConditionals
|
||||||
auto pruner =
|
auto pruner = [&](const Assignment<Key> &choices,
|
||||||
[&](const Assignment<Key> &choices,
|
const GaussianConditional::shared_ptr &conditional)
|
||||||
const GaussianFactor::shared_ptr &gf) -> GaussianFactor::shared_ptr {
|
-> GaussianConditional::shared_ptr {
|
||||||
// typecast so we can use this to get probability value
|
// typecast so we can use this to get probability value
|
||||||
DiscreteValues values(choices);
|
DiscreteValues values(choices);
|
||||||
|
|
||||||
if ((*discreteFactor)(values) == 0.0) {
|
if ((*discreteFactor)(values) == 0.0) {
|
||||||
// empty aka null pointer
|
// empty aka null pointer
|
||||||
boost::shared_ptr<GaussianFactor> null;
|
boost::shared_ptr<GaussianConditional> null;
|
||||||
return null;
|
return null;
|
||||||
} else {
|
} else {
|
||||||
return gf;
|
return conditional;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -74,28 +74,28 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The GaussianMixture stores all the conditionals and uneliminated
|
|
||||||
// factors in the factors tree.
|
|
||||||
auto gaussianMixtureTree = gaussianMixture->factors();
|
|
||||||
|
|
||||||
// Run the pruning to get a new, pruned tree
|
// Run the pruning to get a new, pruned tree
|
||||||
auto prunedFactors = gaussianMixtureTree.apply(pruner);
|
GaussianMixture::Conditionals prunedTree =
|
||||||
|
gaussianMixture->conditionals().apply(pruner);
|
||||||
|
|
||||||
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
|
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
|
||||||
// reverse keys to get a natural ordering
|
// reverse keys to get a natural ordering
|
||||||
std::reverse(discreteKeys.begin(), discreteKeys.end());
|
std::reverse(discreteKeys.begin(), discreteKeys.end());
|
||||||
|
|
||||||
// Convert to GaussianConditionals
|
// Convert from boost::iterator_range to std::vector.
|
||||||
auto prunedTree = GaussianMixture::Conditionals(
|
std::vector<Key> frontals, parents;
|
||||||
prunedFactors, [](const GaussianFactor::shared_ptr &factor) {
|
for (Key key : gaussianMixture->frontals()) {
|
||||||
return boost::dynamic_pointer_cast<GaussianConditional>(factor);
|
frontals.push_back(key);
|
||||||
});
|
}
|
||||||
|
for (Key key : gaussianMixture->parents()) {
|
||||||
|
parents.push_back(key);
|
||||||
|
}
|
||||||
|
|
||||||
// Create the new gaussian mixture and add it to the bayes net.
|
// Create the new gaussian mixture and add it to the bayes net.
|
||||||
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
|
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
|
||||||
gaussianMixture->frontals(), gaussianMixture->parents(), discreteKeys,
|
frontals, parents, discreteKeys, prunedTree);
|
||||||
prunedTree);
|
|
||||||
|
|
||||||
|
// Type-erase and add to the pruned Bayes Net fragment.
|
||||||
prunedBayesNetFragment.push_back(
|
prunedBayesNetFragment.push_back(
|
||||||
boost::make_shared<HybridConditional>(prunedGaussianMixture));
|
boost::make_shared<HybridConditional>(prunedGaussianMixture));
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
|
||||||
|
|
@ -35,7 +36,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
||||||
|
|
||||||
/** Construct empty bayes net */
|
/** Construct empty bayes net */
|
||||||
HybridBayesNet() = default;
|
HybridBayesNet() : Base() {}
|
||||||
|
|
||||||
|
HybridBayesNet prune(
|
||||||
|
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue