From 2225ecf44277c82b0f2d408236235c772b5f5b10 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 11 Oct 2022 12:35:58 -0400 Subject: [PATCH] clean up the prunerFunc --- gtsam/hybrid/GaussianMixture.cpp | 19 ++++++++++--------- gtsam/hybrid/GaussianMixture.h | 11 +++++++++++ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 064905d6b..244d52738 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -131,7 +131,7 @@ void GaussianMixture::print(const std::string &s, /* ************************************************************************* */ /// Return the DiscreteKey vector as a set. -static std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { +std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { std::set s; s.insert(dkeys.begin(), dkeys.end()); return s; @@ -142,18 +142,19 @@ static std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { * @brief Helper function to get the pruner functional. * * @param decisionTree The probability decision tree of only discrete keys. - * @param decisionTreeKeySet Set of DiscreteKeys in decisionTree. - * Pre-computed for efficiency. - * @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture. * @return std::function &, const GaussianConditional::shared_ptr &)> */ std::function &, const GaussianConditional::shared_ptr &)> -PrunerFunc(const DecisionTreeFactor &decisionTree, - const std::set &decisionTreeKeySet, - const std::set &gaussianMixtureKeySet) { - auto pruner = [&](const Assignment &choices, +GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { + // Get the discrete keys as sets for the decision tree + // and the gaussian mixture. + auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); + auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys()); + + auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet]( + const Assignment &choices, const GaussianConditional::shared_ptr &conditional) -> GaussianConditional::shared_ptr { // typecast so we can use this to get probability value @@ -202,7 +203,7 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys()); // Functional which loops over all assignments and create a set of // GaussianConditionals - auto pruner = PrunerFunc(decisionTree, decisionTreeKeySet, gmKeySet); + auto pruner = prunerFunc(decisionTree); auto pruned_conditionals = conditionals_.apply(pruner); conditionals_.root_ = pruned_conditionals.root_; diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 75deb4d55..9792a8532 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -69,6 +69,17 @@ class GTSAM_EXPORT GaussianMixture */ Sum asGaussianFactorGraphTree() const; + /** + * @brief Helper function to get the pruner functor. + * + * @param decisionTree The pruned discrete probability decision tree. + * @return std::function &, const GaussianConditional::shared_ptr &)> + */ + std::function &, const GaussianConditional::shared_ptr &)> + prunerFunc(const DecisionTreeFactor &decisionTree); + public: /// @name Constructors /// @{