From 9e737dbbd8efe75b1a183fc34983b8fab75e6250 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 3 Jun 2022 12:47:21 -0400 Subject: [PATCH] initial pruning method --- gtsam/hybrid/HybridBayesNet.cpp | 95 +++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 1292711d8..7d1834723 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -14,3 +14,98 @@ */ #include + +namespace gtsam { + +/* ************************************************************************* */ +/// Return the DiscreteKey vector as a set. +static std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { + std::set s; + s.insert(dkeys.begin(), dkeys.end()); + return s; +} + +/* ************************************************************************* */ +HybridBayesNet HybridBayesNet::prune( + const DecisionTreeFactor::shared_ptr &discreteFactor) const { + /* To Prune, we visitWith every leaf in the GaussianMixture. + * For each leaf, using the assignment we can check the discrete decision tree + * for 0.0 probability, then just set the leaf to a nullptr. + * + * We can later check the GaussianMixture for just nullptrs. + */ + + HybridBayesNet prunedBayesNetFragment; + + // Functional which loops over all assignments and create a set of + // GaussianConditionals + auto pruner = + [&](const Assignment &choices, + const GaussianFactor::shared_ptr &gf) -> GaussianFactor::shared_ptr { + // typecast so we can use this to get probability value + DiscreteValues values(choices); + + if ((*discreteFactor)(values) == 0.0) { + // empty aka null pointer + boost::shared_ptr null; + return null; + } else { + return gf; + } + }; + + // Go through all the conditionals in the + // Bayes Net and prune them as per discreteFactor. + for (size_t i = 0; i < this->size(); i++) { + HybridConditional::shared_ptr conditional = this->at(i); + + GaussianMixture::shared_ptr gaussianMixture = + boost::dynamic_pointer_cast(conditional->inner()); + + if (gaussianMixture) { + // We may have mixtures with less discrete keys than discreteFactor so we + // skip those since the label assignment does not exist. + auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys()); + auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys()); + if (gmKeySet != dfKeySet) { + // Add the gaussianMixture which doesn't have to be pruned. + prunedBayesNetFragment.push_back( + boost::make_shared(gaussianMixture)); + continue; + } + + // The GaussianMixture stores all the conditionals and uneliminated + // factors in the factors tree. + auto gaussianMixtureTree = gaussianMixture->factors(); + + // Run the pruning to get a new, pruned tree + auto prunedFactors = gaussianMixtureTree.apply(pruner); + + DiscreteKeys discreteKeys = gaussianMixture->discreteKeys(); + // reverse keys to get a natural ordering + std::reverse(discreteKeys.begin(), discreteKeys.end()); + + // Convert to GaussianConditionals + auto prunedTree = GaussianMixture::Conditionals( + prunedFactors, [](const GaussianFactor::shared_ptr &factor) { + return boost::dynamic_pointer_cast(factor); + }); + + // Create the new gaussian mixture and add it to the bayes net. + auto prunedGaussianMixture = boost::make_shared( + gaussianMixture->frontals(), gaussianMixture->parents(), discreteKeys, + prunedTree); + + prunedBayesNetFragment.push_back( + boost::make_shared(prunedGaussianMixture)); + + } else { + // Add the non-GaussianMixture conditional + prunedBayesNetFragment.push_back(conditional); + } + } + + return prunedBayesNetFragment; +} + +} // namespace gtsam