From a00bcbcac92d45a5dcc06d7c74d7f25ef9feb32d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 10 Oct 2022 16:03:36 -0400 Subject: [PATCH] PrunerFunc helper function --- gtsam/hybrid/HybridBayesNet.cpp | 65 +++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 788970790..163e77e47 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -49,6 +49,38 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { return boost::make_shared(dtFactor); } +/** + * @brief Helper function to get the pruner functional. + * + * @param probDecisionTree The probability decision tree of only discrete keys. + * @param discreteFactorKeySet Set of DiscreteKeys in probDecisionTree. + * Pre-computed for efficiency. + * @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture. + * @return std::function &, const GaussianConditional::shared_ptr &)> + */ +std::function &, const GaussianConditional::shared_ptr &)> +PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree, + const std::set &discreteFactorKeySet, + const std::set &gaussianMixtureKeySet) { + auto pruner = [&](const Assignment &choices, + const GaussianConditional::shared_ptr &conditional) + -> GaussianConditional::shared_ptr { + // typecast so we can use this to get probability value + DiscreteValues values(choices); + + if ((*probDecisionTree)(values) == 0.0) { + // empty aka null pointer + boost::shared_ptr null; + return null; + } else { + return conditional; + } + }; + return pruner; +} + /* ************************************************************************* */ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { // Get the decision tree of only the discrete keys @@ -57,6 +89,8 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { boost::make_shared( discreteConditionals->prune(maxNrLeaves)); + auto discreteFactorKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys()); + /* To Prune, we visitWith every leaf in the GaussianMixture. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. @@ -66,23 +100,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { HybridBayesNet prunedBayesNetFragment; - // Functional which loops over all assignments and create a set of - // GaussianConditionals - auto pruner = [&](const Assignment &choices, - const GaussianConditional::shared_ptr &conditional) - -> GaussianConditional::shared_ptr { - // typecast so we can use this to get probability value - DiscreteValues values(choices); - - if ((*discreteFactor)(values) == 0.0) { - // empty aka null pointer - boost::shared_ptr null; - return null; - } else { - return conditional; - } - }; - // Go through all the conditionals in the // Bayes Net and prune them as per discreteFactor. for (size_t i = 0; i < this->size(); i++) { @@ -92,17 +109,19 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { boost::dynamic_pointer_cast(conditional->inner()); if (gaussianMixture) { - // We may have mixtures with less discrete keys than discreteFactor so we - // skip those since the label assignment does not exist. + // We may have mixtures with less discrete keys than discreteFactor so + // we skip those since the label assignment does not exist. auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys()); - auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys()); - if (gmKeySet != dfKeySet) { + if (gmKeySet != discreteFactorKeySet) { // Add the gaussianMixture which doesn't have to be pruned. prunedBayesNetFragment.push_back( boost::make_shared(gaussianMixture)); continue; } + // Get the pruner function. + auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet); + // Run the pruning to get a new, pruned tree GaussianMixture::Conditionals prunedTree = gaussianMixture->conditionals().apply(pruner); @@ -173,7 +192,7 @@ GaussianBayesNet HybridBayesNet::choose( return gbn; } -/* *******************************************************************************/ +/* ************************************************************************* */ HybridValues HybridBayesNet::optimize() const { // Solve for the MPE DiscreteBayesNet discrete_bn; @@ -190,7 +209,7 @@ HybridValues HybridBayesNet::optimize() const { return HybridValues(mpe, gbn.optimize()); } -/* *******************************************************************************/ +/* ************************************************************************* */ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { GaussianBayesNet gbn = this->choose(assignment); return gbn.optimize();