initial pruning method
parent
b47cd9d97b
commit
9e737dbbd8
|
|
@ -14,3 +14,98 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
/// Return the DiscreteKey vector as a set.
|
||||
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||
std::set<DiscreteKey> s;
|
||||
s.insert(dkeys.begin(), dkeys.end());
|
||||
return s;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const {
|
||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||
*
|
||||
* We can later check the GaussianMixture for just nullptrs.
|
||||
*/
|
||||
|
||||
HybridBayesNet prunedBayesNetFragment;
|
||||
|
||||
// Functional which loops over all assignments and create a set of
|
||||
// GaussianConditionals
|
||||
auto pruner =
|
||||
[&](const Assignment<Key> &choices,
|
||||
const GaussianFactor::shared_ptr &gf) -> GaussianFactor::shared_ptr {
|
||||
// typecast so we can use this to get probability value
|
||||
DiscreteValues values(choices);
|
||||
|
||||
if ((*discreteFactor)(values) == 0.0) {
|
||||
// empty aka null pointer
|
||||
boost::shared_ptr<GaussianFactor> null;
|
||||
return null;
|
||||
} else {
|
||||
return gf;
|
||||
}
|
||||
};
|
||||
|
||||
// Go through all the conditionals in the
|
||||
// Bayes Net and prune them as per discreteFactor.
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
|
||||
GaussianMixture::shared_ptr gaussianMixture =
|
||||
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());
|
||||
|
||||
if (gaussianMixture) {
|
||||
// We may have mixtures with less discrete keys than discreteFactor so we
|
||||
// skip those since the label assignment does not exist.
|
||||
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
|
||||
auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
|
||||
if (gmKeySet != dfKeySet) {
|
||||
// Add the gaussianMixture which doesn't have to be pruned.
|
||||
prunedBayesNetFragment.push_back(
|
||||
boost::make_shared<HybridConditional>(gaussianMixture));
|
||||
continue;
|
||||
}
|
||||
|
||||
// The GaussianMixture stores all the conditionals and uneliminated
|
||||
// factors in the factors tree.
|
||||
auto gaussianMixtureTree = gaussianMixture->factors();
|
||||
|
||||
// Run the pruning to get a new, pruned tree
|
||||
auto prunedFactors = gaussianMixtureTree.apply(pruner);
|
||||
|
||||
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
|
||||
// reverse keys to get a natural ordering
|
||||
std::reverse(discreteKeys.begin(), discreteKeys.end());
|
||||
|
||||
// Convert to GaussianConditionals
|
||||
auto prunedTree = GaussianMixture::Conditionals(
|
||||
prunedFactors, [](const GaussianFactor::shared_ptr &factor) {
|
||||
return boost::dynamic_pointer_cast<GaussianConditional>(factor);
|
||||
});
|
||||
|
||||
// Create the new gaussian mixture and add it to the bayes net.
|
||||
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
|
||||
gaussianMixture->frontals(), gaussianMixture->parents(), discreteKeys,
|
||||
prunedTree);
|
||||
|
||||
prunedBayesNetFragment.push_back(
|
||||
boost::make_shared<HybridConditional>(prunedGaussianMixture));
|
||||
|
||||
} else {
|
||||
// Add the non-GaussianMixture conditional
|
||||
prunedBayesNetFragment.push_back(conditional);
|
||||
}
|
||||
}
|
||||
|
||||
return prunedBayesNetFragment;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
Loading…
Reference in New Issue