From e15c44ec5c711f9efae695c0d03fa3b8ed24bbb8 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 22:56:41 -0700 Subject: [PATCH] Make prune functional --- gtsam/hybrid/HybridGaussianConditional.cpp | 50 ++++++++++++++++++++-- gtsam/hybrid/HybridGaussianConditional.h | 4 +- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 1db13e95b..3c5130f42 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -342,13 +342,57 @@ HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) { } /* *******************************************************************************/ -void HybridGaussianConditional::prune(const DecisionTreeFactor &discreteProbs) { +HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( + const DecisionTreeFactor &discreteProbs) const { + auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); + auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys()); + // Functional which loops over all assignments and create a set of // GaussianConditionals - auto pruner = prunerFunc(discreteProbs); + auto pruner = [&](const Assignment &choices, + const GaussianConditional::shared_ptr &conditional) + -> GaussianConditional::shared_ptr { + // typecast so we can use this to get probability value + const DiscreteValues values(choices); + + // Case where the hybrid gaussian conditional has the same + // discrete keys as the decision tree. + if (hybridGaussianCondKeySet == discreteProbsKeySet) { + if (discreteProbs(values) == 0.0) { + // empty aka null pointer + std::shared_ptr null; + return null; + } else { + return conditional; + } + } else { + std::vector set_diff; + std::set_difference( + discreteProbsKeySet.begin(), discreteProbsKeySet.end(), + hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(), + std::back_inserter(set_diff)); + + const std::vector assignments = + DiscreteValues::CartesianProduct(set_diff); + for (const DiscreteValues &assignment : assignments) { + DiscreteValues augmented_values(values); + augmented_values.insert(assignment); + + // If any one of the sub-branches are non-zero, + // we need this conditional. + if (discreteProbs(augmented_values) > 0.0) { + return conditional; + } + } + // If we are here, it means that all the sub-branches are 0, + // so we prune. + return nullptr; + } + }; auto pruned_conditionals = conditionals_.apply(pruner); - conditionals_.root_ = pruned_conditionals.root_; + return std::make_shared(discreteKeys(), + pruned_conditionals); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 68c63e7bd..ede748b16 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -225,8 +225,10 @@ class GTSAM_EXPORT HybridGaussianConditional * `discreteProbs`. * * @param discreteProbs A pruned set of probabilities for the discrete keys. + * @return Shared pointer to possibly a pruned HybridGaussianConditional */ - void prune(const DecisionTreeFactor &discreteProbs); + HybridGaussianConditional::shared_ptr prune( + const DecisionTreeFactor &discreteProbs) const; /// @}