From 98cdf1193facd7715cd37de002fb50762a3d2e2a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 17:54:29 -0500 Subject: [PATCH 01/11] Fix pruning --- gtsam/hybrid/HybridBayesNet.cpp | 90 +++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 27 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e64284a94..2efb8030e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -49,6 +49,9 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // search to find the K-best leaves and then create a single pruned conditional. HybridBayesNet HybridBayesNet::prune( size_t maxNrLeaves, const std::optional &deadModeThreshold) const { +#if GTSAM_HYBRID_TIMING + gttic_(HybridPruning); +#endif // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); @@ -69,6 +72,10 @@ HybridBayesNet HybridBayesNet::prune( // If we have a dead mode threshold and discrete variables left after pruning, // then we run dead mode removal. if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { +#if GTSAM_HYBRID_TIMING + gttic_(DeadModeRemoval); +#endif + DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); for (auto dkey : pruned.discreteKeys()) { Vector probabilities = marginals.marginalProbabilities(dkey); @@ -89,24 +96,11 @@ HybridBayesNet HybridBayesNet::prune( // Remove the modes (imperative) pruned.removeDiscreteModes(deadModesValues); - /* - If the pruned discrete conditional has any keys left, - we add it to the HybridBayesNet. - If not, it means it is an orphan so we don't add this pruned joint, - and instead add only the marginals below. - */ - if (pruned.keys().size() > 0) { - result.emplace_shared(pruned); - } + GTSAM_PRINT(deadModesValues); - // Add the marginals for future factors - for (auto &&[key, _] : deadModesValues) { - result.push_back( - std::dynamic_pointer_cast(marginals(key))); - } - - } else { - result.emplace_shared(pruned); +#if GTSAM_HYBRID_TIMING + gttoc_(DeadModeRemoval); +#endif } /* To prune, we visitWith every leaf in the HybridGaussianConditional. @@ -122,20 +116,37 @@ HybridBayesNet HybridBayesNet::prune( if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); + if (!prunedHybridGaussianConditional) { + GTSAM_PRINT(marginal); + GTSAM_PRINT(pruned); + throw std::runtime_error( + "A HybridGaussianConditional had all its conditionals pruned"); + } if (deadModeThreshold.has_value()) { - KeyVector deadKeys, conditionalDiscreteKeys; - for (const auto &kv : deadModesValues) { - deadKeys.push_back(kv.first); + const auto &discreteParents = + prunedHybridGaussianConditional->discreteKeys(); + DiscreteValues deadParentValues; + DiscreteKeys liveParents; + for (const auto &key : discreteParents) { + auto it = deadModesValues.find(key.first); + if (it != deadModesValues.end()) + deadParentValues[key.first] = it->second; + else + liveParents.emplace_back(key); } - for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) { - conditionalDiscreteKeys.push_back(dkey.first); - } - // The discrete keys in the conditional are the same as the keys in the - // dead modes, then we just get the corresponding Gaussian conditional. - if (deadKeys == conditionalDiscreteKeys) { + // If so then we just get the corresponding Gaussian conditional: + if (deadParentValues.size() == discreteParents.size()) { + // print on how many discreteParents we are choosing: result.push_back( - prunedHybridGaussianConditional->choose(deadModesValues)); + prunedHybridGaussianConditional->choose(deadParentValues)); + } else if (liveParents.size() > 0) { + auto newTree = prunedHybridGaussianConditional->factors(); + for (auto &&[key, value] : deadModesValues) { + newTree = newTree.choose(key, value); + } + result.emplace_shared(liveParents, + newTree); } else { // Add as-is result.push_back(prunedHybridGaussianConditional); @@ -152,6 +163,31 @@ HybridBayesNet HybridBayesNet::prune( // We ignore DiscreteConditional as they are already pruned and added. } +#if GTSAM_HYBRID_TIMING + gttoc_(HybridPruning); +#endif + + if (deadModeThreshold.has_value()) { + /* + If the pruned discrete conditional has any keys left, + we add it to the HybridBayesNet. + If not, it means it is an orphan so we don't add this pruned joint, + and instead add only the marginals below. + */ + if (pruned.keys().size() > 0) { + result.emplace_shared(pruned); + } + + // Add the marginals for future factors + // for (auto &&[key, _] : deadModesValues) { + // result.push_back( + // std::dynamic_pointer_cast(marginals(key))); + // } + + } else { + result.emplace_shared(pruned); + } + return result; } From d17215c69b0837023e7069f7fd2cbfdc1359e3fb Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 21:36:51 -0500 Subject: [PATCH 02/11] prototype choose --- .../tests/testHybridGaussianConditional.cpp | 95 +++++++++++++++---- 1 file changed, 78 insertions(+), 17 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 8bb83cac4..7c7fc3ea5 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -238,22 +238,27 @@ TEST(HybridGaussianConditional, Likelihood2) { EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); } +/* ************************************************************************* */ +namespace two_mode_measurement { +// Create a two key conditional: +const DiscreteKeys modes{{M(1), 2}, {M(2), 2}}; +const std::vector gcs = { + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(1), 1), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(2), 2), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(3), 3), + GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(4), 4)}; +const HybridGaussianConditional::Conditionals conditionals(modes, gcs); +const auto hgc = + std::make_shared(modes, conditionals); +} // namespace two_mode_measurement + /* ************************************************************************* */ // Test pruning a HybridGaussianConditional with two discrete keys, based on a // DecisionTreeFactor with 3 keys: TEST(HybridGaussianConditional, Prune) { - // Create a two key conditional: - DiscreteKeys modes{{M(1), 2}, {M(2), 2}}; - std::vector gcs; - for (size_t i = 0; i < 4; i++) { - gcs.push_back( - GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1)); - } - auto empty = std::make_shared(); - HybridGaussianConditional::Conditionals conditionals(modes, gcs); - HybridGaussianConditional hgc(modes, conditionals); + using two_mode_measurement::hgc; - DiscreteKeys keys = modes; + DiscreteKeys keys = two_mode_measurement::modes; keys.push_back({M(3), 2}); { for (size_t i = 0; i < 8; i++) { @@ -262,7 +267,7 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); // Prune the HybridGaussianConditional const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 1 conditional EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); } @@ -273,14 +278,14 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 2 conditionals EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); // Check that the minimum negLogConstant is set correctly EXPECT_DOUBLES_EQUAL( - hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), + hgc->conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), pruned->negLogConstant(), 1e-9); } { @@ -289,18 +294,74 @@ TEST(HybridGaussianConditional, Prune) { const DecisionTreeFactor decisionTreeFactor(keys, potentials); const auto pruned = - hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); + hgc->prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 3 conditionals EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); // Check that the minimum negLogConstant is correct - EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9); + EXPECT_DOUBLES_EQUAL(hgc->negLogConstant(), pruned->negLogConstant(), 1e-9); } } -/* ************************************************************************* +/* ************************************************************************* */ + +#include + +/** + * Return a HybridConditional by choosing branches based on the given discrete + * values. If all discrete parents are specified, return a HybridConditional + * which is just a GaussianConditional. */ +HybridConditional::shared_ptr choose( + const HybridGaussianConditional::shared_ptr &self, + const DiscreteValues &discreteValues) { + const auto &discreteParents = self->discreteKeys(); + DiscreteValues deadParentValues; + DiscreteKeys liveParents; + for (const auto &key : discreteParents) { + auto it = discreteValues.find(key.first); + if (it != discreteValues.end()) + deadParentValues[key.first] = it->second; + else + liveParents.emplace_back(key); + } + // If so then we just get the corresponding Gaussian conditional: + if (deadParentValues.size() == discreteParents.size()) { + // print on how many discreteParents we are choosing: + return std::make_shared(self->choose(deadParentValues)); + } else if (liveParents.size() > 0) { + auto newTree = self->factors(); + for (auto &&[key, value] : discreteValues) { + newTree = newTree.choose(key, value); + } + return std::make_shared( + std::make_shared(liveParents, newTree)); + } else { + // Add as-is + return std::make_shared(self); + } +} + +/* ************************************************************************* */ +// Test the pruning and dead-mode removal. +TEST(HybridGaussianConditional, PrunePlus) { + using two_mode_measurement::hgc; // two discrete parents + + const HybridConditional::shared_ptr same = choose(hgc, {}); + EXPECT(same->isHybrid()); + EXPECT(same->asHybrid()->nrComponents() == 4); + + const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}}); + EXPECT(oneParent->isHybrid()); + EXPECT(oneParent->asHybrid()->nrComponents() == 2); + + const HybridConditional::shared_ptr gaussian = + choose(hgc, {{M(1), 0}, {M(2), 1}}); + EXPECT(gaussian->asGaussian()); +} + +/* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); From 803dae75f31e9b9222a783602678457d303e1c82 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 22:04:56 -0500 Subject: [PATCH 03/11] filter and missingKeys --- gtsam/discrete/DiscreteValues.h | 82 ++++++++++++++++----- gtsam/discrete/tests/testDiscreteValues.cpp | 18 +++++ 2 files changed, 83 insertions(+), 17 deletions(-) diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index df4ecdbff..7c73da681 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -68,28 +68,74 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x); // insert in base class; - std::pair insert( const value_type& value ){ + std::pair insert(const value_type& value) { return Base::insert(value); } /** - * Insert key-assignment pair. - * Throws an invalid_argument exception if - * any keys to be inserted are already used. */ + * @brief Insert key-assignment pair. + * + * @param assignment The key-assignment pair to insert. + * @return DiscreteValues& Reference to the updated DiscreteValues object. + * @throws std::invalid_argument if any keys to be inserted are already used. + */ DiscreteValues& insert(const std::pair& assignment); - /** Insert all values from \c values. Throws an invalid_argument exception if - * any keys to be inserted are already used. */ + /** + * @brief Insert all values from another DiscreteValues object. + * + * @param values The DiscreteValues object containing values to insert. + * @return DiscreteValues& Reference to the updated DiscreteValues object. + * @throws std::invalid_argument if any keys to be inserted are already used. + */ DiscreteValues& insert(const DiscreteValues& values); - /** For all key/value pairs in \c values, replace values with corresponding - * keys in this object with those in \c values. Throws std::out_of_range if - * any keys in \c values are not present in this object. */ + /** + * @brief Update values with corresponding keys from another DiscreteValues + * object. + * + * @param values The DiscreteValues object containing values to update. + * @return DiscreteValues& Reference to the updated DiscreteValues object. + * @throws std::out_of_range if any keys in values are not present in this + * object. + */ DiscreteValues& update(const DiscreteValues& values); /** - * @brief Return a vector of DiscreteValues, one for each possible - * combination of values. + * @brief Filter values by keys. + * + * @param keys The keys to filter by. + * @return DiscreteValues The filtered DiscreteValues object. + */ + DiscreteValues filter(const DiscreteKeys& keys) const { + DiscreteValues result; + for (const auto& [key, _] : keys) { + if (auto it = this->find(key); it != this->end()) + result[key] = it->second; + } + return result; + } + + /** + * @brief Return the keys that are not present in the DiscreteValues object. + * + * @param keys The keys to check for. + * @return DiscreteKeys Keys not present in the DiscreteValues object. + */ + DiscreteKeys missingKeys(const DiscreteKeys& keys) const { + DiscreteKeys result; + for (const auto& [key, cardinality] : keys) { + if (!this->contains(key)) result.emplace_back(key, cardinality); + } + return result; + } + + /** + * @brief Return a vector of DiscreteValues, one for each possible combination + * of values. + * + * @param keys The keys to generate the Cartesian product for. + * @return std::vector The vector of DiscreteValues. */ static std::vector CartesianProduct( const DiscreteKeys& keys) { @@ -135,14 +181,16 @@ inline std::vector cartesianProduct(const DiscreteKeys& keys) { } /// Free version of markdown. -std::string GTSAM_EXPORT markdown(const DiscreteValues& values, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const DiscreteValues::Names& names = {}); +std::string GTSAM_EXPORT +markdown(const DiscreteValues& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteValues::Names& names = {}); /// Free version of html. -std::string GTSAM_EXPORT html(const DiscreteValues& values, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const DiscreteValues::Names& names = {}); +std::string GTSAM_EXPORT +html(const DiscreteValues& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteValues::Names& names = {}); // traits template <> diff --git a/gtsam/discrete/tests/testDiscreteValues.cpp b/gtsam/discrete/tests/testDiscreteValues.cpp index 47989481c..4667bc5ea 100644 --- a/gtsam/discrete/tests/testDiscreteValues.cpp +++ b/gtsam/discrete/tests/testDiscreteValues.cpp @@ -40,6 +40,24 @@ TEST(DiscreteValues, Update) { DiscreteValues(kExample).update({{12, 2}}))); } +/* ************************************************************************* */ +// Test DiscreteValues::filter +TEST(DiscreteValues, Filter) { + DiscreteValues values = {{12, 1}, {5, 0}, {13, 2}}; + DiscreteKeys keys = {{12, 0}, {13, 0}, {99, 0}}; // 99 is missing in values + + EXPECT(assert_equal(DiscreteValues({{12, 1}, {13, 2}}), values.filter(keys))); +} + +/* ************************************************************************* */ +// Test DiscreteValues::missingKeys +TEST(DiscreteValues, MissingKeys) { + DiscreteValues values = {{12, 1}, {5, 0}}; + DiscreteKeys keys = {{12, 0}, {5, 0}, {99, 0}, {42, 0}}; // 99 and 42 are missing + + EXPECT(assert_equal(DiscreteKeys({{99, 0}, {42, 0}}), values.missingKeys(keys))); +} + /* ************************************************************************* */ // Check markdown representation with a value formatter. TEST(DiscreteValues, markdownWithValueFormatter) { From 8746b15a4a4188692956054d9b7bf4ad5d95269e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 22:08:12 -0500 Subject: [PATCH 04/11] Use two new methods --- .../tests/testHybridGaussianConditional.cpp | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 7c7fc3ea5..ba6eaf4cd 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -308,6 +308,14 @@ TEST(HybridGaussianConditional, Prune) { #include +// Helper function to apply discrete values to the tree +auto choose(auto tree, const DiscreteValues &discreteValues) { + for (const auto &[key, value] : discreteValues) { + tree = tree.choose(key, value); + } + return tree; +} + /** * Return a HybridConditional by choosing branches based on the given discrete * values. If all discrete parents are specified, return a HybridConditional @@ -316,31 +324,27 @@ TEST(HybridGaussianConditional, Prune) { HybridConditional::shared_ptr choose( const HybridGaussianConditional::shared_ptr &self, const DiscreteValues &discreteValues) { - const auto &discreteParents = self->discreteKeys(); - DiscreteValues deadParentValues; - DiscreteKeys liveParents; - for (const auto &key : discreteParents) { - auto it = discreteValues.find(key.first); - if (it != discreteValues.end()) - deadParentValues[key.first] = it->second; - else - liveParents.emplace_back(key); + auto parentValues = discreteValues.filter(self->discreteKeys()); + auto unspecifiedParentKeys = discreteValues.missingKeys(self->discreteKeys()); + + // Case 1: Fully determined, return corresponding Gaussian conditional + if (parentValues.size() == self->discreteKeys().size()) { + return std::make_shared(self->choose(parentValues)); } - // If so then we just get the corresponding Gaussian conditional: - if (deadParentValues.size() == discreteParents.size()) { - // print on how many discreteParents we are choosing: - return std::make_shared(self->choose(deadParentValues)); - } else if (liveParents.size() > 0) { + + // Case 2: Some live parents remain, build a new tree + if (!unspecifiedParentKeys.empty()) { auto newTree = self->factors(); - for (auto &&[key, value] : discreteValues) { + for (const auto &[key, value] : parentValues) { newTree = newTree.choose(key, value); } return std::make_shared( - std::make_shared(liveParents, newTree)); - } else { - // Add as-is - return std::make_shared(self); + std::make_shared(unspecifiedParentKeys, + newTree)); } + + // Case 3: No changes needed, return original + return std::make_shared(self); } /* ************************************************************************* */ @@ -356,6 +360,11 @@ TEST(HybridGaussianConditional, PrunePlus) { EXPECT(oneParent->isHybrid()); EXPECT(oneParent->asHybrid()->nrComponents() == 2); + const HybridConditional::shared_ptr oneParent2 = + choose(hgc, {{M(7), 0}, {M(1), 0}}); + EXPECT(oneParent2->isHybrid()); + EXPECT(oneParent2->asHybrid()->nrComponents() == 2); + const HybridConditional::shared_ptr gaussian = choose(hgc, {{M(1), 0}, {M(2), 1}}); EXPECT(gaussian->asGaussian()); From 4d1a8e50572eb9da23ce729e3de592e3ee17db2c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 22:46:33 -0500 Subject: [PATCH 05/11] restrict method --- gtsam/hybrid/HybridConditional.cpp | 37 +++++++++++++ gtsam/hybrid/HybridConditional.h | 8 +++ .../tests/testHybridGaussianConditional.cpp | 53 ++++++------------- 3 files changed, 60 insertions(+), 38 deletions(-) diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 97ec1a1f8..257eca314 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -169,4 +169,41 @@ double HybridConditional::evaluate(const HybridValues &values) const { return std::exp(logProbability(values)); } +/* ************************************************************************ */ +HybridConditional::shared_ptr HybridConditional::restrict( + const DiscreteValues &discreteValues) const { + if (auto gc = asGaussian()) { + return std::make_shared(gc); + } else if (auto dc = asDiscrete()) { + return std::make_shared(dc); + }; + + auto hgc = asHybrid(); + if (!hgc) + throw std::runtime_error( + "HybridConditional::restrict: conditional type not handled"); + + // Case 1: Fully determined, return corresponding Gaussian conditional + auto parentValues = discreteValues.filter(discreteKeys_); + if (parentValues.size() == discreteKeys_.size()) { + return std::make_shared(hgc->choose(parentValues)); + } + + // Case 2: Some live parents remain, build a new tree + auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_); + if (!unspecifiedParentKeys.empty()) { + auto newTree = hgc->factors(); + for (const auto &[key, value] : parentValues) { + newTree = newTree.choose(key, value); + } + return std::make_shared( + std::make_shared(unspecifiedParentKeys, + newTree)); + } + + // Case 3: No changes needed, return original + return std::make_shared(hgc); +} + +/* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 3cf5b80e5..075fbe411 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -215,6 +215,14 @@ class GTSAM_EXPORT HybridConditional return true; } + /** + * Return a HybridConditional by choosing branches based on the given discrete + * values. If all discrete parents are specified, return a HybridConditional + * which is just a GaussianConditional. If this conditional is *not* a hybrid + * conditional, just return that. + */ + shared_ptr restrict(const DiscreteValues& discreteValues) const; + /// @} private: diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index ba6eaf4cd..032be5a78 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -316,57 +316,34 @@ auto choose(auto tree, const DiscreteValues &discreteValues) { return tree; } -/** - * Return a HybridConditional by choosing branches based on the given discrete - * values. If all discrete parents are specified, return a HybridConditional - * which is just a GaussianConditional. +/* ************************************************************************* + * This test verifies the behavior of the restrict method in different + * scenarios: + * - When no restrictions are applied. + * - When one parent is restricted. + * - When two parents are restricted. + * - When the restriction results in a Gaussian conditional. */ -HybridConditional::shared_ptr choose( - const HybridGaussianConditional::shared_ptr &self, - const DiscreteValues &discreteValues) { - auto parentValues = discreteValues.filter(self->discreteKeys()); - auto unspecifiedParentKeys = discreteValues.missingKeys(self->discreteKeys()); +TEST(HybridGaussianConditional, Restrict) { + // Create a HybridConditional with two discrete parents P(z0|m0,m1) + const auto hc = + std::make_shared(two_mode_measurement::hgc); - // Case 1: Fully determined, return corresponding Gaussian conditional - if (parentValues.size() == self->discreteKeys().size()) { - return std::make_shared(self->choose(parentValues)); - } - - // Case 2: Some live parents remain, build a new tree - if (!unspecifiedParentKeys.empty()) { - auto newTree = self->factors(); - for (const auto &[key, value] : parentValues) { - newTree = newTree.choose(key, value); - } - return std::make_shared( - std::make_shared(unspecifiedParentKeys, - newTree)); - } - - // Case 3: No changes needed, return original - return std::make_shared(self); -} - -/* ************************************************************************* */ -// Test the pruning and dead-mode removal. -TEST(HybridGaussianConditional, PrunePlus) { - using two_mode_measurement::hgc; // two discrete parents - - const HybridConditional::shared_ptr same = choose(hgc, {}); + const HybridConditional::shared_ptr same = hc->restrict({}); EXPECT(same->isHybrid()); EXPECT(same->asHybrid()->nrComponents() == 4); - const HybridConditional::shared_ptr oneParent = choose(hgc, {{M(1), 0}}); + const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}}); EXPECT(oneParent->isHybrid()); EXPECT(oneParent->asHybrid()->nrComponents() == 2); const HybridConditional::shared_ptr oneParent2 = - choose(hgc, {{M(7), 0}, {M(1), 0}}); + hc->restrict({{M(7), 0}, {M(1), 0}}); EXPECT(oneParent2->isHybrid()); EXPECT(oneParent2->asHybrid()->nrComponents() == 2); const HybridConditional::shared_ptr gaussian = - choose(hgc, {{M(1), 0}, {M(2), 1}}); + hc->restrict({{M(1), 0}, {M(2), 1}}); EXPECT(gaussian->asGaussian()); } From 05ad198ca6097627469833d3a9080789aa9a55d5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 23:37:58 -0500 Subject: [PATCH 06/11] Use restrict inside prune --- gtsam/hybrid/HybridBayesNet.cpp | 80 ++++++++------------------------- 1 file changed, 19 insertions(+), 61 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 2efb8030e..a911a047a 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -96,71 +96,36 @@ HybridBayesNet HybridBayesNet::prune( // Remove the modes (imperative) pruned.removeDiscreteModes(deadModesValues); - GTSAM_PRINT(deadModesValues); - #if GTSAM_HYBRID_TIMING gttoc_(DeadModeRemoval); #endif } - /* To prune, we visitWith every leaf in the HybridGaussianConditional. - * For each leaf, using the assignment we can check the discrete decision tree - * for 0.0 probability, then just set the leaf to a nullptr. - * - * We can later check the HybridGaussianConditional for just nullptrs. - */ - - // Go through all the Gaussian conditionals in the Bayes Net and prune them as - // per pruned discrete joint. + // Go through all the Gaussian conditionals, restrict them according to + // deadModesValues, and then prune further. for (auto &&conditional : *this) { - if (auto hgc = conditional->asHybrid()) { + if (conditional->isDiscrete()) continue; + + // Restrict conditional using deadModesValues. + // No-op if not a HybridGaussianConditional or deadModesValues empty. + auto restricted = conditional->restrict(deadModesValues); + + // Now decide on type what to do: + if (auto hgc = restricted->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); if (!prunedHybridGaussianConditional) { - GTSAM_PRINT(marginal); - GTSAM_PRINT(pruned); throw std::runtime_error( "A HybridGaussianConditional had all its conditionals pruned"); } - - if (deadModeThreshold.has_value()) { - const auto &discreteParents = - prunedHybridGaussianConditional->discreteKeys(); - DiscreteValues deadParentValues; - DiscreteKeys liveParents; - for (const auto &key : discreteParents) { - auto it = deadModesValues.find(key.first); - if (it != deadModesValues.end()) - deadParentValues[key.first] = it->second; - else - liveParents.emplace_back(key); - } - // If so then we just get the corresponding Gaussian conditional: - if (deadParentValues.size() == discreteParents.size()) { - // print on how many discreteParents we are choosing: - result.push_back( - prunedHybridGaussianConditional->choose(deadParentValues)); - } else if (liveParents.size() > 0) { - auto newTree = prunedHybridGaussianConditional->factors(); - for (auto &&[key, value] : deadModesValues) { - newTree = newTree.choose(key, value); - } - result.emplace_shared(liveParents, - newTree); - } else { - // Add as-is - result.push_back(prunedHybridGaussianConditional); - } - } else { - // Type-erase and add to the pruned Bayes Net fragment. - result.push_back(prunedHybridGaussianConditional); - } - - } else if (auto gc = conditional->asGaussian()) { + // Type-erase and add to the pruned Bayes Net fragment. + result.push_back(prunedHybridGaussianConditional); + } else if (auto gc = restricted->asGaussian()) { // Add the non-HybridGaussianConditional conditional result.push_back(gc); - } - // We ignore DiscreteConditional as they are already pruned and added. + } else + throw std::runtime_error( + "HybrdiBayesNet::prune: Unknown HybridConditional type."); } #if GTSAM_HYBRID_TIMING @@ -169,21 +134,14 @@ HybridBayesNet HybridBayesNet::prune( if (deadModeThreshold.has_value()) { /* - If the pruned discrete conditional has any keys left, - we add it to the HybridBayesNet. - If not, it means it is an orphan so we don't add this pruned joint, - and instead add only the marginals below. + If the pruned discrete conditional has any keys left, we add it to the + HybridBayesNet. If not, it means it is an orphan so we don't add this + pruned joint, and instead add only the marginals below. */ if (pruned.keys().size() > 0) { result.emplace_shared(pruned); } - // Add the marginals for future factors - // for (auto &&[key, _] : deadModesValues) { - // result.push_back( - // std::dynamic_pointer_cast(marginals(key))); - // } - } else { result.emplace_shared(pruned); } From e6662b820658e397c85e3e8e3ca4e95d60084a0b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 23:38:06 -0500 Subject: [PATCH 07/11] Fix unit tests --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 37 ++++++++++++----------- gtsam/hybrid/tests/testHybridSmoother.cpp | 4 +-- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 86dcd48e4..3ddad23ff 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -166,8 +166,8 @@ TEST(HybridBayesNet, Tiny) { // prune auto pruned = bayesNet.prune(1); - CHECK(pruned.at(1)->asHybrid()); - EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents()); + CHECK(pruned.at(0)->asHybrid()); + EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); EXPECT(!pruned.equals(bayesNet)); // error @@ -437,20 +437,21 @@ TEST(HybridBayesNet, RemoveDeadNodes) { const double pruneDeadVariables = 0.99; auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); + // First conditional is still the same: P( x0 | x1 m0) + EXPECT(prunedBayesNet.at(0)->isHybrid()); + + // Check that hybrid conditional that only depend on M1 + // is now Gaussian and not Hybrid + EXPECT(prunedBayesNet.at(1)->isContinuous()); + + // Third conditional is still Hybrid: P( x1 | m0 m1) -> P( x1 | m0) + EXPECT(prunedBayesNet.at(0)->isHybrid()); + // Check that discrete joint only has M0 and not (M0, M1) // since M0 is removed - KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys(); - EXPECT(KeyVector{M(0)} == actual_keys); - - // Check that hybrid conditionals that only depend on M1 - // are now Gaussian and not Hybrid - EXPECT(prunedBayesNet.at(0)->isDiscrete()); - EXPECT(prunedBayesNet.at(1)->isDiscrete()); - EXPECT(prunedBayesNet.at(2)->isHybrid()); - // Only P(X2 | X1, M1) depends on M1, - // so it gets convert to a Gaussian P(X2 | X1) - EXPECT(prunedBayesNet.at(3)->isContinuous()); - EXPECT(prunedBayesNet.at(4)->isHybrid()); + auto joint = prunedBayesNet.at(3)->asDiscrete(); + EXPECT(joint); + EXPECT(joint->keys() == KeyVector{M(0)}); } /* ****************************************************************************/ @@ -479,13 +480,13 @@ TEST(HybridBayesNet, ErrorAfterPruning) { const HybridValues hybridValues{delta.continuous(), discrete_values}; double pruned_logProbability = 0; pruned_logProbability += - prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues); + prunedBayesNet.at(0)->asHybrid()->logProbability(hybridValues); pruned_logProbability += prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues); pruned_logProbability += prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues); pruned_logProbability += - prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues); + prunedBayesNet.at(3)->asDiscrete()->logProbability(hybridValues); double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values); @@ -548,8 +549,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { }; // Get the pruned discrete conditionals as an AlgebraicDecisionTree - CHECK(pruned.at(0)->asDiscrete()); - auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete(); + CHECK(pruned.at(4)->asDiscrete()); + auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete(); auto discrete_conditional_tree = std::dynamic_pointer_cast( pruned_discrete_conditionals); diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 3a0f376cc..97a302faf 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -103,7 +103,7 @@ TEST(HybridSmoother, IncrementalSmoother) { } EXPECT_LONGS_EQUAL(11, - smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(5)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment. @@ -157,7 +157,7 @@ TEST(HybridSmoother, ValidPruningError) { } EXPECT_LONGS_EQUAL(14, - smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(8)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment. From 57bcb9f4e6b54f3086355bce64906fca977da227 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 23:50:43 -0500 Subject: [PATCH 08/11] Add contains --- gtsam/discrete/DiscreteValues.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index 7c73da681..0644e0c16 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -101,6 +101,14 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { */ DiscreteValues& update(const DiscreteValues& values); + /** + * @brief Check if the DiscreteValues contains the given key. + * + * @param key The key to check for. + * @return True if the key is present, false otherwise. + */ + bool contains(Key key) const { return this->find(key) != this->end(); } + /** * @brief Filter values by keys. * From 957c967d0c186b09e61a1bcac40b3661d37b8beb Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 30 Jan 2025 00:02:34 -0500 Subject: [PATCH 09/11] remove old remnant --- .../hybrid/tests/testHybridGaussianConditional.cpp | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 032be5a78..88a4fa485 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -304,18 +305,6 @@ TEST(HybridGaussianConditional, Prune) { } } -/* ************************************************************************* */ - -#include - -// Helper function to apply discrete values to the tree -auto choose(auto tree, const DiscreteValues &discreteValues) { - for (const auto &[key, value] : discreteValues) { - tree = tree.choose(key, value); - } - return tree; -} - /* ************************************************************************* * This test verifies the behavior of the restrict method in different * scenarios: From 3c10913c7042ee196f2147fa6347826523c62a41 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 30 Jan 2025 07:57:42 -0500 Subject: [PATCH 10/11] DiscreteNet::prune --- gtsam/discrete/DiscreteBayesNet.cpp | 54 +++++++++++++++++ gtsam/discrete/DiscreteBayesNet.h | 12 ++++ gtsam/hybrid/HybridBayesNet.cpp | 89 ++++++----------------------- 3 files changed, 85 insertions(+), 70 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 56265b0a4..7d929f12c 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -18,6 +18,7 @@ #include #include +#include #include namespace gtsam { @@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { return result; } +/* ************************************************************************* */ +// The implementation is: build the entire joint into one factor and then prune. +// TODO(Frank): This can be quite expensive *unless* the factors have already +// been pruned before. Another, possibly faster approach is branch and bound +// search to find the K-best leaves and then create a single pruned conditional. +DiscreteBayesNet DiscreteBayesNet::prune( + size_t maxNrLeaves, const std::optional& deadModeThreshold, + DiscreteValues* fixedValues) const { + // Multiply into one big conditional. NOTE: possibly quite expensive. + DiscreteConditional joint; + for (const DiscreteConditional::shared_ptr& conditional : *this) + joint = joint * (*conditional); + + // Prune the joint. NOTE: imperative and, again, possibly quite expensive. + DiscreteConditional pruned = joint; + pruned.prune(maxNrLeaves); + + DiscreteValues deadModesValues; + // If we have a dead mode threshold and discrete variables left after pruning, + // then we run dead mode removal. + if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { + DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); + for (auto dkey : pruned.discreteKeys()) { + const Vector probabilities = marginals.marginalProbabilities(dkey); + + int index = -1; + auto threshold = (probabilities.array() > *deadModeThreshold); + // If atleast 1 value is non-zero, then we can find the index + // Else if all are zero, index would be set to 0 which is incorrect + if (!threshold.isZero()) { + threshold.maxCoeff(&index); + } + + if (index >= 0) { + deadModesValues.emplace(dkey.first, index); + } + } + + // Remove the modes (imperative) + pruned.removeDiscreteModes(deadModesValues); + + // Set the fixed values if requested. + if (fixedValues) { + *fixedValues = deadModesValues; + } + } + + // Return the resulting DiscreteBayesNet. + DiscreteBayesNet result; + if (pruned.keys().size() > 0) result.push_back(pruned); + return result; +} + /* *********************************************************************** */ std::string DiscreteBayesNet::markdown( const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 738b91aa5..01b452865 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { */ DiscreteValues sample(DiscreteValues given) const; + /** + * @brief Prune the Bayes net + * + * @param maxNrLeaves The maximum number of leaves to keep. + * @param deadModeThreshold If given, threshold on marginals to prune variables. + * @param fixedValues If given, return the fixed values removed. + * @return A new DiscreteBayesNet with pruned conditionals. + */ + DiscreteBayesNet prune(size_t maxNrLeaves, + const std::optional& deadModeThreshold = {}, + DiscreteValues* fixedValues = nullptr) const; + ///@} /// @name Wrapper support /// @{ diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index a911a047a..ea4c3e80b 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -43,10 +42,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { } /* ************************************************************************* */ -// The implementation is: build the entire joint into one factor and then prune. -// TODO(Frank): This can be quite expensive *unless* the factors have already -// been pruned before. Another, possibly faster approach is branch and bound -// search to find the K-best leaves and then create a single pruned conditional. HybridBayesNet HybridBayesNet::prune( size_t maxNrLeaves, const std::optional &deadModeThreshold) const { #if GTSAM_HYBRID_TIMING @@ -55,63 +50,31 @@ HybridBayesNet HybridBayesNet::prune( // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); + // Prune discrete Bayes net + DiscreteValues fixed; + auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed); + // Multiply into one big conditional. NOTE: possibly quite expensive. - DiscreteConditional joint; - for (auto &&conditional : marginal) { - joint = joint * (*conditional); + DiscreteConditional pruned; + for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); + + // Set the fixed values if requested. + if (deadModeThreshold && fixedValues) { + *fixedValues = fixed; } - // Initialize the resulting HybridBayesNet. HybridBayesNet result; - // Prune the joint. NOTE: imperative and, again, possibly quite expensive. - DiscreteConditional pruned = joint; - pruned.prune(maxNrLeaves); - - DiscreteValues deadModesValues; - // If we have a dead mode threshold and discrete variables left after pruning, - // then we run dead mode removal. - if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { -#if GTSAM_HYBRID_TIMING - gttic_(DeadModeRemoval); -#endif - - DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); - for (auto dkey : pruned.discreteKeys()) { - Vector probabilities = marginals.marginalProbabilities(dkey); - - int index = -1; - auto threshold = (probabilities.array() > *deadModeThreshold); - // If atleast 1 value is non-zero, then we can find the index - // Else if all are zero, index would be set to 0 which is incorrect - if (!threshold.isZero()) { - threshold.maxCoeff(&index); - } - - if (index >= 0) { - deadModesValues.emplace(dkey.first, index); - } - } - - // Remove the modes (imperative) - pruned.removeDiscreteModes(deadModesValues); - -#if GTSAM_HYBRID_TIMING - gttoc_(DeadModeRemoval); -#endif - } - // Go through all the Gaussian conditionals, restrict them according to - // deadModesValues, and then prune further. - for (auto &&conditional : *this) { + // fixed values, and then prune further. + for (std::shared_ptr conditional : *this) { if (conditional->isDiscrete()) continue; - // Restrict conditional using deadModesValues. - // No-op if not a HybridGaussianConditional or deadModesValues empty. - auto restricted = conditional->restrict(deadModesValues); + // No-op if not a HybridGaussianConditional. + if (deadModeThreshold) conditional = conditional->restrict(fixed); // Now decide on type what to do: - if (auto hgc = restricted->asHybrid()) { + if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); if (!prunedHybridGaussianConditional) { @@ -120,7 +83,7 @@ HybridBayesNet HybridBayesNet::prune( } // Type-erase and add to the pruned Bayes Net fragment. result.push_back(prunedHybridGaussianConditional); - } else if (auto gc = restricted->asGaussian()) { + } else if (auto gc = conditional->asGaussian()) { // Add the non-HybridGaussianConditional conditional result.push_back(gc); } else @@ -128,23 +91,9 @@ HybridBayesNet HybridBayesNet::prune( "HybrdiBayesNet::prune: Unknown HybridConditional type."); } -#if GTSAM_HYBRID_TIMING - gttoc_(HybridPruning); -#endif - - if (deadModeThreshold.has_value()) { - /* - If the pruned discrete conditional has any keys left, we add it to the - HybridBayesNet. If not, it means it is an orphan so we don't add this - pruned joint, and instead add only the marginals below. - */ - if (pruned.keys().size() > 0) { - result.emplace_shared(pruned); - } - - } else { - result.emplace_shared(pruned); - } + // Add the pruned discrete conditionals to the result. + for (const DiscreteConditional::shared_ptr &discrete : prunedBN) + result.push_back(discrete); return result; } From 9bae03a6fa89b931eff1e9c558ebe682ccd2b050 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 30 Jan 2025 08:57:04 -0500 Subject: [PATCH 11/11] Change threshold name --- gtsam/discrete/DiscreteBayesNet.cpp | 6 +++--- gtsam/discrete/DiscreteBayesNet.h | 4 ++-- gtsam/hybrid/HybridBayesNet.cpp | 9 +++++---- gtsam/hybrid/HybridBayesNet.h | 18 ++++++++++-------- gtsam/hybrid/HybridSmoother.cpp | 2 +- gtsam/hybrid/HybridSmoother.h | 9 +++++---- 6 files changed, 26 insertions(+), 22 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 7d929f12c..8c04cb91c 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -75,7 +75,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const { // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. DiscreteBayesNet DiscreteBayesNet::prune( - size_t maxNrLeaves, const std::optional& deadModeThreshold, + size_t maxNrLeaves, const std::optional& marginalThreshold, DiscreteValues* fixedValues) const { // Multiply into one big conditional. NOTE: possibly quite expensive. DiscreteConditional joint; @@ -89,13 +89,13 @@ DiscreteBayesNet DiscreteBayesNet::prune( DiscreteValues deadModesValues; // If we have a dead mode threshold and discrete variables left after pruning, // then we run dead mode removal. - if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { + if (marginalThreshold.has_value() && pruned.keys().size() > 0) { DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); for (auto dkey : pruned.discreteKeys()) { const Vector probabilities = marginals.marginalProbabilities(dkey); int index = -1; - auto threshold = (probabilities.array() > *deadModeThreshold); + auto threshold = (probabilities.array() > *marginalThreshold); // If atleast 1 value is non-zero, then we can find the index // Else if all are zero, index would be set to 0 which is incorrect if (!threshold.isZero()) { diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 01b452865..eea1739f6 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -128,12 +128,12 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { * @brief Prune the Bayes net * * @param maxNrLeaves The maximum number of leaves to keep. - * @param deadModeThreshold If given, threshold on marginals to prune variables. + * @param marginalThreshold If given, threshold on marginals to prune variables. * @param fixedValues If given, return the fixed values removed. * @return A new DiscreteBayesNet with pruned conditionals. */ DiscreteBayesNet prune(size_t maxNrLeaves, - const std::optional& deadModeThreshold = {}, + const std::optional& marginalThreshold = {}, DiscreteValues* fixedValues = nullptr) const; ///@} diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index ea4c3e80b..5353fe2e0 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -43,7 +43,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { /* ************************************************************************* */ HybridBayesNet HybridBayesNet::prune( - size_t maxNrLeaves, const std::optional &deadModeThreshold) const { + size_t maxNrLeaves, const std::optional &marginalThreshold, + DiscreteValues *fixedValues) const { #if GTSAM_HYBRID_TIMING gttic_(HybridPruning); #endif @@ -52,14 +53,14 @@ HybridBayesNet HybridBayesNet::prune( // Prune discrete Bayes net DiscreteValues fixed; - auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed); + auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed); // Multiply into one big conditional. NOTE: possibly quite expensive. DiscreteConditional pruned; for (auto &&conditional : prunedBN) pruned = pruned * (*conditional); // Set the fixed values if requested. - if (deadModeThreshold && fixedValues) { + if (marginalThreshold && fixedValues) { *fixedValues = fixed; } @@ -71,7 +72,7 @@ HybridBayesNet HybridBayesNet::prune( if (conditional->isDiscrete()) continue; // No-op if not a HybridGaussianConditional. - if (deadModeThreshold) conditional = conditional->restrict(fixed); + if (marginalThreshold) conditional = conditional->restrict(fixed); // Now decide on type what to do: if (auto hgc = conditional->asHybrid()) { diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index fb05e2407..0840cb381 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. - * @param deadModeThreshold The threshold to check the mode marginals against. - * If greater than this threshold, the mode gets assigned that value and is - * considered "dead" for hybrid elimination. - * The mode can then be removed since it only has a single possible - * assignment. + * @param marginalThreshold The threshold to check the mode marginals against. + * @param fixedValues The fixed values resulting from dead mode removal. + * + * @note If marginal greater than this threshold, the mode gets assigned that + * value and is considered "dead" for hybrid elimination. The mode can then be + * removed since it only has a single possible assignment. + * @return A pruned HybridBayesNet */ - HybridBayesNet prune( - size_t maxNrLeaves, - const std::optional &deadModeThreshold = {}) const; + HybridBayesNet prune(size_t maxNrLeaves, + const std::optional &marginalThreshold = {}, + DiscreteValues *fixedValues = nullptr) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 45320896a..594c12825 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_); } // Add the partial bayes net to the posterior bayes net. diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index c3f022c62..2f7bfcebb 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother { HybridGaussianFactorGraph remainingFactorGraph_; /// The threshold above which we make a decision about a mode. - std::optional deadModeThreshold_; + std::optional marginalThreshold_; + DiscreteValues fixedValues_; public: /** * @brief Constructor * * @param removeDeadModes Flag indicating whether to remove dead modes. - * @param deadModeThreshold The threshold above which a mode gets assigned a + * @param marginalThreshold The threshold above which a mode gets assigned a * value and is considered "dead". 0.99 is a good starting value. */ - HybridSmoother(const std::optional deadModeThreshold = {}) - : deadModeThreshold_(deadModeThreshold) {} + HybridSmoother(const std::optional marginalThreshold = {}) + : marginalThreshold_(marginalThreshold) {} /** * Given new factors, perform an incremental update.