From 3c50f9387c4b84d3e1627de5e179e680c29279f2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 29 Oct 2024 01:55:04 -0400 Subject: [PATCH] add pruning support for HybridNonlinearFactor --- gtsam/hybrid/HybridNonlinearFactor.cpp | 31 ++++++++++++++++++++++++++ gtsam/hybrid/HybridNonlinearFactor.h | 7 ++++++ 2 files changed, 38 insertions(+) diff --git a/gtsam/hybrid/HybridNonlinearFactor.cpp b/gtsam/hybrid/HybridNonlinearFactor.cpp index 6ffb95511..48c327156 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.cpp +++ b/gtsam/hybrid/HybridNonlinearFactor.cpp @@ -16,6 +16,7 @@ * @date Sep 12, 2024 */ +#include #include #include #include @@ -202,4 +203,34 @@ std::shared_ptr HybridNonlinearFactor::linearize( linearized_factors); } +/* *******************************************************************************/ +HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune( + const DecisionTreeFactor& discreteProbs) const { + // Find keys in discreteProbs.keys() but not in this->keys(): + std::set mine(this->keys().begin(), this->keys().end()); + std::set theirs(discreteProbs.keys().begin(), + discreteProbs.keys().end()); + std::vector diff; + std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), + std::back_inserter(diff)); + + // Find maximum probability value for every combination of our keys. + Ordering keys(diff); + auto max = discreteProbs.max(keys); + + // Check the max value for every combination of our keys. + // If the max value is 0.0, we can prune the corresponding conditional. + auto pruner = + [&](const Assignment& choices, + const NonlinearFactorValuePair& pair) -> NonlinearFactorValuePair { + if (max->evaluate(choices) == 0.0) + return {nullptr, std::numeric_limits::infinity()}; + else + return pair; + }; + + FactorValuePairs prunedFactors = factors().apply(pruner); + return std::make_shared(discreteKeys(), prunedFactors); +} + } // namespace gtsam \ No newline at end of file diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index 325fa3eaa..e264b1d10 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -166,6 +166,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { /// @} + /// Getter for NonlinearFactor decision tree + const FactorValuePairs& factors() const { return factors_; } + /// Linearize specific nonlinear factors based on the assignment in /// discreteValues. GaussianFactor::shared_ptr linearize( @@ -176,6 +179,10 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { std::shared_ptr linearize( const Values& continuousValues) const; + /// Prune this factor based on the discrete probabilities. + HybridNonlinearFactor::shared_ptr prune( + const DecisionTreeFactor& discreteProbs) const; + private: /// Helper struct to assist private constructor below. struct ConstructorHelper;