From 1dd3b180b1781a4be2af3f933e4b180d5cca559d Mon Sep 17 00:00:00 2001 From: Jeffrey Date: Mon, 28 Oct 2024 20:33:06 +0800 Subject: [PATCH 1/8] update testIncrementalFixedLagSmoother.cpp to reproduce the bug in marginalization --- gtsam/inference/VariableIndex.h | 6 +- .../tests/testIncrementalFixedLagSmoother.cpp | 97 ++++++++++++++++++- 2 files changed, 100 insertions(+), 3 deletions(-) diff --git a/gtsam/inference/VariableIndex.h b/gtsam/inference/VariableIndex.h index 207ded0ce..110c0bba4 100644 --- a/gtsam/inference/VariableIndex.h +++ b/gtsam/inference/VariableIndex.h @@ -87,9 +87,11 @@ class GTSAM_EXPORT VariableIndex { const FactorIndices& operator[](Key variable) const { KeyMap::const_iterator item = index_.find(variable); if(item == index_.end()) - throw std::invalid_argument("Requested non-existent variable from VariableIndex"); + throw std::invalid_argument("Requested non-existent variable '" + + DefaultKeyFormatter(variable) + + "' from VariableIndex"); else - return item->second; + return item->second; } /// Return true if no factors associated with a variable diff --git a/gtsam_unstable/nonlinear/tests/testIncrementalFixedLagSmoother.cpp b/gtsam_unstable/nonlinear/tests/testIncrementalFixedLagSmoother.cpp index 3454c352a..cb4dcbf9a 100644 --- a/gtsam_unstable/nonlinear/tests/testIncrementalFixedLagSmoother.cpp +++ b/gtsam_unstable/nonlinear/tests/testIncrementalFixedLagSmoother.cpp @@ -49,6 +49,41 @@ bool check_smoother(const NonlinearFactorGraph& fullgraph, const Values& fullini return assert_equal(expected, actual); } +/* ************************************************************************* */ +void PrintSymbolicTreeHelper( + const ISAM2Clique::shared_ptr& clique, const std::string indent = "") { + + // Print the current clique + std::cout << indent << "P( "; + for(Key key: clique->conditional()->frontals()) { + std::cout << DefaultKeyFormatter(key) << " "; + } + if (clique->conditional()->nrParents() > 0) + std::cout << "| "; + for(Key key: clique->conditional()->parents()) { + std::cout << DefaultKeyFormatter(key) << " "; + } + std::cout << ")" << std::endl; + + // Recursively print all of the children + for(const ISAM2Clique::shared_ptr& child: clique->children) { + PrintSymbolicTreeHelper(child, indent + " "); + } +} + +/* ************************************************************************* */ +void PrintSymbolicTree(const ISAM2& isam, + const std::string& label) { + std::cout << label << std::endl; + if (!isam.roots().empty()) { + for(const ISAM2::sharedClique& root: isam.roots()) { + PrintSymbolicTreeHelper(root); + } + } else + std::cout << "{Empty Tree}" << std::endl; +} + + /* ************************************************************************* */ TEST( IncrementalFixedLagSmoother, Example ) { @@ -64,7 +99,7 @@ TEST( IncrementalFixedLagSmoother, Example ) // Create a Fixed-Lag Smoother typedef IncrementalFixedLagSmoother::KeyTimestampMap Timestamps; - IncrementalFixedLagSmoother smoother(7.0, ISAM2Params()); + IncrementalFixedLagSmoother smoother(9.0, ISAM2Params()); // Create containers to keep the full graph Values fullinit; @@ -158,6 +193,9 @@ TEST( IncrementalFixedLagSmoother, Example ) Values newValues; Timestamps newTimestamps; + // Add the odometry factor twice to ensure the removeFactor test below works, + // where we need to keep the connectivity of the graph. + newFactors.push_back(BetweenFactor(key1, key2, Point2(1.0, 0.0), odometerNoise)); newFactors.push_back(BetweenFactor(key1, key2, Point2(1.0, 0.0), odometerNoise)); newValues.insert(key2, Point2(double(i)+0.1, -0.1)); newTimestamps[key2] = double(i); @@ -210,6 +248,10 @@ TEST( IncrementalFixedLagSmoother, Example ) const NonlinearFactorGraph smootherFactorsBeforeRemove = smoother.getFactors(); + std::cout << "fullgraph.size() = " << fullgraph.size() << std::endl; + std::cout << "smootherFactorsBeforeRemove.size() = " + << smootherFactorsBeforeRemove.size() << std::endl; + // remove factor smoother.update(emptyNewFactors, emptyNewValues, emptyNewTimestamps,factorToRemove); @@ -231,6 +273,59 @@ TEST( IncrementalFixedLagSmoother, Example ) } } } + + { + PrintSymbolicTree(smoother.getISAM2(), "Bayes Tree Before marginalization test:"); + + i = 17; + while(i <= 200) { + Key key_0 = MakeKey(i); + Key key_1 = MakeKey(i-1); + Key key_2 = MakeKey(i-2); + Key key_3 = MakeKey(i-3); + Key key_4 = MakeKey(i-4); + Key key_5 = MakeKey(i-5); + Key key_6 = MakeKey(i-6); + Key key_7 = MakeKey(i-7); + Key key_8 = MakeKey(i-8); + + NonlinearFactorGraph newFactors; + Values newValues; + Timestamps newTimestamps; + + // To make a complex graph + newFactors.push_back(BetweenFactor(key_1, key_0, Point2(1.0, 0.0), odometerNoise)); + if (i % 2 == 0) + newFactors.push_back(BetweenFactor(key_2, key_1, Point2(1.0, 0.0), odometerNoise)); + if (i % 3 == 0) + newFactors.push_back(BetweenFactor(key_3, key_2, Point2(1.0, 0.0), odometerNoise)); + if (i % 4 == 0) + newFactors.push_back(BetweenFactor(key_4, key_3, Point2(1.0, 0.0), odometerNoise)); + if (i % 5 == 0) + newFactors.push_back(BetweenFactor(key_5, key_4, Point2(1.0, 0.0), odometerNoise)); + if (i % 6 == 0) + newFactors.push_back(BetweenFactor(key_6, key_5, Point2(1.0, 0.0), odometerNoise)); + if (i % 7 == 0) + newFactors.push_back(BetweenFactor(key_7, key_6, Point2(1.0, 0.0), odometerNoise)); + if (i % 8 == 0) + newFactors.push_back(BetweenFactor(key_8, key_7, Point2(1.0, 0.0), odometerNoise)); + + newValues.insert(key_0, Point2(double(i)+0.1, -0.1)); + newTimestamps[key_0] = double(i); + + fullgraph.push_back(newFactors); + fullinit.insert(newValues); + + // Update the smoother + smoother.update(newFactors, newValues, newTimestamps); + + // Check + CHECK(check_smoother(fullgraph, fullinit, smoother, key_0)); + PrintSymbolicTree(smoother.getISAM2(), "Bayes Tree marginalization test: i = " + std::to_string(i)); + + ++i; + } + } } /* ************************************************************************* */ From 896e52ca27baf823112a10c43d5442318dccd271 Mon Sep 17 00:00:00 2001 From: Jeffrey Date: Mon, 28 Oct 2024 23:38:04 +0800 Subject: [PATCH 2/8] Fix marginalization in IncrementalFixedLagSmoother. Add BayesTreeMarginalizationHelper.h and use the new helper to gather the additional keys to re-eliminate when marginalizing variables in IncrementalFixedLagSmoother. --- .../BayesTreeMarginalizationHelper.h | 174 ++++++++++++++++++ .../nonlinear/IncrementalFixedLagSmoother.cpp | 10 + .../tests/testIncrementalFixedLagSmoother.cpp | 15 +- 3 files changed, 196 insertions(+), 3 deletions(-) create mode 100644 gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h diff --git a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h new file mode 100644 index 000000000..fe261d10e --- /dev/null +++ b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h @@ -0,0 +1,174 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file BayesTreeMarginalizationHelper.h + * @brief Helper functions for marginalizing variables from a Bayes Tree. + * + * @author Jeffrey (Zhiwei Wang) + * @date Oct 28, 2024 + */ + +// \callgraph +#pragma once + +#include +#include +#include +#include "gtsam_unstable/dllexport.h" + +namespace gtsam { + +/** + * This class provides helper functions for marginalizing variables from a Bayes Tree. + */ +template +class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper { + +public: + using Clique = typename BayesTree::Clique; + using sharedClique = typename BayesTree::sharedClique; + + /** Get the additional keys that need to be re-eliminated when marginalizing + * the variables in @p marginalizableKeys from the Bayes tree @p bayesTree. + * + * @param[in] bayesTree The Bayes tree to be marginalized. + * @param[in] marginalizableKeys The keys to be marginalized. + * + * + * When marginalizing a variable @f$ \theta @f$ in a Bayes tree, some nodes + * may need to be re-eliminated. The variable to be marginalized should be + * eliminated first. + * + * 1. If @f$ \theta @f$ is already in a leaf node @f$ L @f$, and all other + * frontal variables within @f$ L @f$ are to be marginalized, then this + * node does not need to be re-eliminated; the entire node can be directly + * marginalized. + * + * 2. If @f$ \theta @f$ is in a leaf node @f$ L @f$, but @f$ L @f$ contains + * other frontal variables that do not need to be marginalized: + * a. If all other non-marginalized frontal variables are listed after + * @f$ \theta @f$ (each node contains a frontal list, with variables to + * be eliminated earlier in the list), then node @f$ L @f$ does not + * need to be re-eliminated. + * b. Otherwise, if there are non-marginalized nodes listed before + * @f$ \theta @f$, then node @f$ L @f$ needs to be re-eliminated, and + * correspondingly, all nodes between @f$ L @f$ and the root need to be + * re-eliminated. + * + * 3. If @f$ \theta @f$ is in an intermediate node @f$ M @f$ (non-leaf node), + * but none of @f$ M @f$'s child nodes depend on variable @f$ \theta @f$ + * (they only depend on other variables within @f$ M @f$), then during the + * process of marginalizing @f$ \theta @f$, @f$ M @f$ can be treated as a + * leaf node, and @f$ M @f$ should be processed following the same + * approach as for leaf nodes. + * + * In this case, the original elimination of @f$ \theta @f$ does not + * depend on the elimination results of variables in the child nodes. + * + * 4. If @f$ \theta @f$ is in an intermediate node @f$ M @f$ (non-leaf node), + * and there exist child nodes that depend on variable @f$ \theta @f$, + * then not only does node @f$ M @f$ need to be re-eliminated, but all + * child nodes dependent on @f$ \theta @f$, including descendant nodes + * recursively dependent on @f$ \theta @f$, also need to be re-eliminated. + * + * The frontal variables in child nodes were originally eliminated before + * @f$ \theta @f$ and their elimination results are relied upon by + * @f$ \theta @f$'s elimination. When re-eliminating, they should be + * eliminated after @f$ \theta @f$. + */ + static void gatherAdditionalKeysToReEliminate( + const BayesTree& bayesTree, + const KeyVector& marginalizableKeys, + std::set& additionalKeys) { + const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); + + std::set marginalizableKeySet(marginalizableKeys.begin(), marginalizableKeys.end()); + std::set checkedCliques; + + std::set dependentCliques; + for (const Key& key : marginalizableKeySet) { + sharedClique clique = bayesTree[key]; + if (checkedCliques.count(clique)) { + continue; + } + checkedCliques.insert(clique); + + bool is_leaf = clique->children.empty(); + bool need_reeliminate = false; + bool has_non_marginalizable_ahead = false; + for (Key i: clique->conditional()->frontals()) { + if (marginalizableKeySet.count(i)) { + if (has_non_marginalizable_ahead) { + need_reeliminate = true; + break; + } else { + // Check whether there're child nodes dependent on this key. + for(const sharedClique& child: clique->children) { + if (std::find(child->conditional()->beginParents(), + child->conditional()->endParents(), i) + != child->conditional()->endParents()) { + need_reeliminate = true; + break; + } + } + } + } else { + has_non_marginalizable_ahead = true; + } + } + + if (!need_reeliminate) { + continue; + } else { + // need to re-eliminate this clique and all its children that depend on + // a marginalizable key + for (Key i: clique->conditional()->frontals()) { + additionalKeys.insert(i); + for (const sharedClique& child: clique->children) { + if (!dependentCliques.count(child) && + std::find(child->conditional()->beginParents(), + child->conditional()->endParents(), i) + != child->conditional()->endParents()) { + dependentCliques.insert(child); + } + } + } + } + } + + // Recursively add the dependent keys + while (!dependentCliques.empty()) { + auto begin = dependentCliques.begin(); + sharedClique clique = *begin; + dependentCliques.erase(begin); + + for (Key key : clique->conditional()->frontals()) { + additionalKeys.insert(key); + } + + for (const sharedClique& child: clique->children) { + dependentCliques.insert(child); + } + } + + if (debug) { + std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; + for (const Key& key : additionalKeys) { + std::cout << DefaultKeyFormatter(key) << " "; + } + std::cout << std::endl; + } + } +}; +// BayesTreeMarginalizationHelper + +}/// namespace gtsam diff --git a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp index 52e56260d..9d27f5713 100644 --- a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp +++ b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp @@ -20,11 +20,15 @@ */ #include +#include #include +// #define GTSAM_OLD_MARGINALIZATION + namespace gtsam { /* ************************************************************************* */ +#ifdef GTSAM_OLD_MARGINALIZATION void recursiveMarkAffectedKeys(const Key& key, const ISAM2Clique::shared_ptr& clique, std::set& additionalKeys) { @@ -45,6 +49,7 @@ void recursiveMarkAffectedKeys(const Key& key, } // If the key was not found in the separator/parents, then none of its children can have it either } +#endif /* ************************************************************************* */ void IncrementalFixedLagSmoother::print(const std::string& s, @@ -116,12 +121,17 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update( // Mark additional keys between the marginalized keys and the leaves std::set additionalKeys; +#ifdef GTSAM_OLD_MARGINALIZATION for(Key key: marginalizableKeys) { ISAM2Clique::shared_ptr clique = isam_[key]; for(const ISAM2Clique::shared_ptr& child: clique->children) { recursiveMarkAffectedKeys(key, child, additionalKeys); } } +#else + BayesTreeMarginalizationHelper::gatherAdditionalKeysToReEliminate( + isam_, marginalizableKeys, additionalKeys); +#endif KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end()); // Update iSAM2 diff --git a/gtsam_unstable/nonlinear/tests/testIncrementalFixedLagSmoother.cpp b/gtsam_unstable/nonlinear/tests/testIncrementalFixedLagSmoother.cpp index cb4dcbf9a..cd2ba593b 100644 --- a/gtsam_unstable/nonlinear/tests/testIncrementalFixedLagSmoother.cpp +++ b/gtsam_unstable/nonlinear/tests/testIncrementalFixedLagSmoother.cpp @@ -99,7 +99,7 @@ TEST( IncrementalFixedLagSmoother, Example ) // Create a Fixed-Lag Smoother typedef IncrementalFixedLagSmoother::KeyTimestampMap Timestamps; - IncrementalFixedLagSmoother smoother(9.0, ISAM2Params()); + IncrementalFixedLagSmoother smoother(12.0, ISAM2Params()); // Create containers to keep the full graph Values fullinit; @@ -226,6 +226,7 @@ TEST( IncrementalFixedLagSmoother, Example ) newFactors.push_back(BetweenFactor(key1, key2, Point2(1.0, 0.0), odometerNoise)); newValues.insert(key2, Point2(double(i)+0.1, -0.1)); newTimestamps[key2] = double(i); + ++i; fullgraph.push_back(newFactors); fullinit.insert(newValues); @@ -275,10 +276,12 @@ TEST( IncrementalFixedLagSmoother, Example ) } { + SETDEBUG("BayesTreeMarginalizationHelper", true); PrintSymbolicTree(smoother.getISAM2(), "Bayes Tree Before marginalization test:"); - i = 17; - while(i <= 200) { + // Do pressure test on marginalization. Enlarge max_i to enhance the test. + const int max_i = 500; + while(i <= max_i) { Key key_0 = MakeKey(i); Key key_1 = MakeKey(i-1); Key key_2 = MakeKey(i-2); @@ -288,6 +291,8 @@ TEST( IncrementalFixedLagSmoother, Example ) Key key_6 = MakeKey(i-6); Key key_7 = MakeKey(i-7); Key key_8 = MakeKey(i-8); + Key key_9 = MakeKey(i-9); + Key key_10 = MakeKey(i-10); NonlinearFactorGraph newFactors; Values newValues; @@ -309,6 +314,10 @@ TEST( IncrementalFixedLagSmoother, Example ) newFactors.push_back(BetweenFactor(key_7, key_6, Point2(1.0, 0.0), odometerNoise)); if (i % 8 == 0) newFactors.push_back(BetweenFactor(key_8, key_7, Point2(1.0, 0.0), odometerNoise)); + if (i % 9 == 0) + newFactors.push_back(BetweenFactor(key_9, key_8, Point2(1.0, 0.0), odometerNoise)); + if (i % 10 == 0) + newFactors.push_back(BetweenFactor(key_10, key_9, Point2(1.0, 0.0), odometerNoise)); newValues.insert(key_0, Point2(double(i)+0.1, -0.1)); newTimestamps[key_0] = double(i); From c6ba2b5fd852e29d16eb1d963b216ea98d3d0ff5 Mon Sep 17 00:00:00 2001 From: Jeffrey Date: Tue, 29 Oct 2024 15:13:09 +0800 Subject: [PATCH 3/8] update doc string in BayesTreeMarginalizationHelper.h --- .../BayesTreeMarginalizationHelper.h | 64 +++++++------------ 1 file changed, 23 insertions(+), 41 deletions(-) diff --git a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h index fe261d10e..6f5ff425f 100644 --- a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h +++ b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h @@ -37,53 +37,33 @@ public: using Clique = typename BayesTree::Clique; using sharedClique = typename BayesTree::sharedClique; - /** Get the additional keys that need to be re-eliminated when marginalizing + /** Get the additional keys that need reelimination when marginalizing * the variables in @p marginalizableKeys from the Bayes tree @p bayesTree. * - * @param[in] bayesTree The Bayes tree to be marginalized. + * @param[in] bayesTree The Bayes tree. * @param[in] marginalizableKeys The keys to be marginalized. * * - * When marginalizing a variable @f$ \theta @f$ in a Bayes tree, some nodes - * may need to be re-eliminated. The variable to be marginalized should be - * eliminated first. + * When marginalizing a variable @f$ \theta @f$ from a Bayes tree, some + * nodes may need reelimination to ensure the variables to marginalize + * be eliminated first. * - * 1. If @f$ \theta @f$ is already in a leaf node @f$ L @f$, and all other - * frontal variables within @f$ L @f$ are to be marginalized, then this - * node does not need to be re-eliminated; the entire node can be directly - * marginalized. + * We should consider two cases: * - * 2. If @f$ \theta @f$ is in a leaf node @f$ L @f$, but @f$ L @f$ contains - * other frontal variables that do not need to be marginalized: - * a. If all other non-marginalized frontal variables are listed after - * @f$ \theta @f$ (each node contains a frontal list, with variables to - * be eliminated earlier in the list), then node @f$ L @f$ does not - * need to be re-eliminated. - * b. Otherwise, if there are non-marginalized nodes listed before - * @f$ \theta @f$, then node @f$ L @f$ needs to be re-eliminated, and - * correspondingly, all nodes between @f$ L @f$ and the root need to be - * re-eliminated. + * 1. If a child node relies on @f$ \theta @f$ (i.e., @f$ \theta @f$ + * is a parent / separator of the node), then the frontal + * variables of the child node need to be reeliminated. In + * addition, all the descendants of the child node also need to + * be reeliminated. * - * 3. If @f$ \theta @f$ is in an intermediate node @f$ M @f$ (non-leaf node), - * but none of @f$ M @f$'s child nodes depend on variable @f$ \theta @f$ - * (they only depend on other variables within @f$ M @f$), then during the - * process of marginalizing @f$ \theta @f$, @f$ M @f$ can be treated as a - * leaf node, and @f$ M @f$ should be processed following the same - * approach as for leaf nodes. + * 2. If other frontal variables in the same node with @f$ \theta @f$ + * are in front of @f$ \theta @f$ but not to be marginalized, then + * these variables also need to be reeliminated. * - * In this case, the original elimination of @f$ \theta @f$ does not - * depend on the elimination results of variables in the child nodes. + * These variables were eliminated before @f$ \theta @f$ in the original + * Bayes tree, and after reelimination they will be eliminated after + * @f$ \theta @f$ so that @f$ \theta @f$ can be marginalized safely. * - * 4. If @f$ \theta @f$ is in an intermediate node @f$ M @f$ (non-leaf node), - * and there exist child nodes that depend on variable @f$ \theta @f$, - * then not only does node @f$ M @f$ need to be re-eliminated, but all - * child nodes dependent on @f$ \theta @f$, including descendant nodes - * recursively dependent on @f$ \theta @f$, also need to be re-eliminated. - * - * The frontal variables in child nodes were originally eliminated before - * @f$ \theta @f$ and their elimination results are relied upon by - * @f$ \theta @f$'s elimination. When re-eliminating, they should be - * eliminated after @f$ \theta @f$. */ static void gatherAdditionalKeysToReEliminate( const BayesTree& bayesTree, @@ -102,20 +82,21 @@ public: } checkedCliques.insert(clique); - bool is_leaf = clique->children.empty(); bool need_reeliminate = false; bool has_non_marginalizable_ahead = false; for (Key i: clique->conditional()->frontals()) { if (marginalizableKeySet.count(i)) { if (has_non_marginalizable_ahead) { + // Case 2 in the docstring need_reeliminate = true; break; } else { - // Check whether there're child nodes dependent on this key. + // Check whether there's a child node dependent on this key. for(const sharedClique& child: clique->children) { if (std::find(child->conditional()->beginParents(), child->conditional()->endParents(), i) != child->conditional()->endParents()) { + // Case 1 in the docstring need_reeliminate = true; break; } @@ -127,10 +108,11 @@ public: } if (!need_reeliminate) { + // No variable needs to be reeliminated continue; } else { - // need to re-eliminate this clique and all its children that depend on - // a marginalizable key + // Need to reeliminate the current clique and all its children + // that rely on a marginalizable key. for (Key i: clique->conditional()->frontals()) { additionalKeys.insert(i); for (const sharedClique& child: clique->children) { From 67b0b78ea1eacdb2647fedab828d9dcff99aac10 Mon Sep 17 00:00:00 2001 From: Jeffrey Date: Thu, 31 Oct 2024 18:58:49 +0800 Subject: [PATCH 4/8] Update BayesTreeMarginalizationHelper: 1. Refactor code in BayesTreeMarginalizationHelper; 2. And avoid the unnecessary re-elimination of subtrees that only contain marginalizable variables; --- .../BayesTreeMarginalizationHelper.h | 350 +++++++++++++----- .../nonlinear/IncrementalFixedLagSmoother.cpp | 7 +- 2 files changed, 265 insertions(+), 92 deletions(-) diff --git a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h index 6f5ff425f..04e4e0cac 100644 --- a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h +++ b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h @@ -20,6 +20,7 @@ // \callgraph #pragma once +#include #include #include #include @@ -37,109 +38,54 @@ public: using Clique = typename BayesTree::Clique; using sharedClique = typename BayesTree::sharedClique; - /** Get the additional keys that need reelimination when marginalizing - * the variables in @p marginalizableKeys from the Bayes tree @p bayesTree. + /** + * This function identifies variables that need to be re-eliminated before + * performing marginalization. * - * @param[in] bayesTree The Bayes tree. - * @param[in] marginalizableKeys The keys to be marginalized. + * Re-elimination is necessary for a clique containing marginalizable + * variables if: * + * 1. Some non-marginalizable variables appear before marginalizable ones + * in that clique; + * 2. Or it has a child node depending on a marginalizable variable AND the + * subtree rooted at that child contains non-marginalizables. * - * When marginalizing a variable @f$ \theta @f$ from a Bayes tree, some - * nodes may need reelimination to ensure the variables to marginalize - * be eliminated first. - * - * We should consider two cases: - * - * 1. If a child node relies on @f$ \theta @f$ (i.e., @f$ \theta @f$ - * is a parent / separator of the node), then the frontal - * variables of the child node need to be reeliminated. In - * addition, all the descendants of the child node also need to - * be reeliminated. - * - * 2. If other frontal variables in the same node with @f$ \theta @f$ - * are in front of @f$ \theta @f$ but not to be marginalized, then - * these variables also need to be reeliminated. - * - * These variables were eliminated before @f$ \theta @f$ in the original - * Bayes tree, and after reelimination they will be eliminated after - * @f$ \theta @f$ so that @f$ \theta @f$ can be marginalized safely. + * In addition, the subtrees under the aforementioned cliques that require + * re-elimination, which contain non-marginalizable variables in their root + * node, also need to be re-eliminated. * + * @param[in] bayesTree The Bayes tree + * @param[in] marginalizableKeys Keys to be marginalized + * @return Set of additional keys that need to be re-eliminated */ - static void gatherAdditionalKeysToReEliminate( + static std::set gatherAdditionalKeysToReEliminate( const BayesTree& bayesTree, - const KeyVector& marginalizableKeys, - std::set& additionalKeys) { + const KeyVector& marginalizableKeys) { const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); - std::set marginalizableKeySet(marginalizableKeys.begin(), marginalizableKeys.end()); - std::set checkedCliques; + std::set additionalKeys; + std::set marginalizableKeySet( + marginalizableKeys.begin(), marginalizableKeys.end()); + std::set dependentSubtrees; + CachedSearch cachedSearch; - std::set dependentCliques; - for (const Key& key : marginalizableKeySet) { - sharedClique clique = bayesTree[key]; - if (checkedCliques.count(clique)) { - continue; - } - checkedCliques.insert(clique); + // Check each clique that contains a marginalizable key + for (const sharedClique& clique : + getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { - bool need_reeliminate = false; - bool has_non_marginalizable_ahead = false; - for (Key i: clique->conditional()->frontals()) { - if (marginalizableKeySet.count(i)) { - if (has_non_marginalizable_ahead) { - // Case 2 in the docstring - need_reeliminate = true; - break; - } else { - // Check whether there's a child node dependent on this key. - for(const sharedClique& child: clique->children) { - if (std::find(child->conditional()->beginParents(), - child->conditional()->endParents(), i) - != child->conditional()->endParents()) { - // Case 1 in the docstring - need_reeliminate = true; - break; - } - } - } - } else { - has_non_marginalizable_ahead = true; - } - } + if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { + // Add frontal variables from current clique + addCliqueToKeySet(clique, &additionalKeys); - if (!need_reeliminate) { - // No variable needs to be reeliminated - continue; - } else { - // Need to reeliminate the current clique and all its children - // that rely on a marginalizable key. - for (Key i: clique->conditional()->frontals()) { - additionalKeys.insert(i); - for (const sharedClique& child: clique->children) { - if (!dependentCliques.count(child) && - std::find(child->conditional()->beginParents(), - child->conditional()->endParents(), i) - != child->conditional()->endParents()) { - dependentCliques.insert(child); - } - } - } + // Then gather dependent subtrees to be added later + gatherDependentSubtrees( + clique, marginalizableKeySet, &dependentSubtrees, &cachedSearch); } } - // Recursively add the dependent keys - while (!dependentCliques.empty()) { - auto begin = dependentCliques.begin(); - sharedClique clique = *begin; - dependentCliques.erase(begin); - - for (Key key : clique->conditional()->frontals()) { - additionalKeys.insert(key); - } - - for (const sharedClique& child: clique->children) { - dependentCliques.insert(child); - } + // Add the remaining dependent cliques + for (const sharedClique& subtree : dependentSubtrees) { + addSubtreeToKeySet(subtree, &additionalKeys); } if (debug) { @@ -149,6 +95,232 @@ public: } std::cout << std::endl; } + + return additionalKeys; + } + + protected: + + /** + * Gather the cliques containing any of the given keys. + * + * @param[in] bayesTree The Bayes tree + * @param[in] keysOfInterest Set of keys of interest + * @return Set of cliques that contain any of the given keys + */ + static std::set getCliquesContainingKeys( + const BayesTree& bayesTree, + const std::set& keysOfInterest) { + std::set cliques; + for (const Key& key : keysOfInterest) { + cliques.insert(bayesTree[key]); + } + return cliques; + } + + /** + * A struct to cache the results of the below two functions. + */ + struct CachedSearch { + std::unordered_map wholeMarginalizableCliques; + std::unordered_map wholeMarginalizableSubtrees; + }; + + /** + * Check if all variables in the clique are marginalizable. + * + * Note we use a cache map to avoid repeated searches. + */ + static bool isWholeCliqueMarginalizable( + const sharedClique& clique, + const std::set& marginalizableKeys, + CachedSearch* cache) { + auto it = cache->wholeMarginalizableCliques.find(clique.get()); + if (it != cache->wholeMarginalizableCliques.end()) { + return it->second; + } else { + bool ret = true; + for (Key key : clique->conditional()->frontals()) { + if (!marginalizableKeys.count(key)) { + ret = false; + break; + } + } + cache->wholeMarginalizableCliques.insert({clique.get(), ret}); + return ret; + } + } + + /** + * Check if all variables in the subtree are marginalizable. + * + * Note we use a cache map to avoid repeated searches. + */ + static bool isWholeSubtreeMarginalizable( + const sharedClique& subtree, + const std::set& marginalizableKeys, + CachedSearch* cache) { + auto it = cache->wholeMarginalizableSubtrees.find(subtree.get()); + if (it != cache->wholeMarginalizableSubtrees.end()) { + return it->second; + } else { + bool ret = true; + if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) { + for (const sharedClique& child : subtree->children) { + if (!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { + ret = false; + break; + } + } + } else { + ret = false; + } + cache->wholeMarginalizableSubtrees.insert({subtree.get(), ret}); + return ret; + } + } + + /** + * Check if a clique contains variables that need reelimination due to + * elimination ordering conflicts. + * + * @param[in] clique The clique to check + * @param[in] marginalizableKeys Set of keys to be marginalized + * @return true if any variables in the clique need re-elimination + */ + static bool needsReelimination( + const sharedClique& clique, + const std::set& marginalizableKeys, + CachedSearch* cache) { + bool hasNonMarginalizableAhead = false; + + // Check each frontal variable in order + for (Key key : clique->conditional()->frontals()) { + if (marginalizableKeys.count(key)) { + // If we've seen non-marginalizable variables before this one, + // we need to reeliminate + if (hasNonMarginalizableAhead) { + return true; + } + + // Check if any child depends on this marginalizable key and the + // subtree rooted at that child contains non-marginalizables. + for (const sharedClique& child : clique->children) { + if (hasDependency(child, key) && + !isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { + return true; + } + } + } else { + hasNonMarginalizableAhead = true; + } + } + return false; + } + + /** + * Gather all subtrees that depend on a marginalizable key and contain + * non-marginalizable variables in their root. + * + * @param[in] rootClique The starting clique + * @param[in] marginalizableKeys Set of keys to be marginalized + * @param[out] dependentSubtrees Pointer to set storing dependent cliques + */ + static void gatherDependentSubtrees( + const sharedClique& rootClique, + const std::set& marginalizableKeys, + std::set* dependentSubtrees, + CachedSearch* cache) { + for (Key key : rootClique->conditional()->frontals()) { + if (marginalizableKeys.count(key)) { + // Find children that depend on this key + for (const sharedClique& child : rootClique->children) { + if (!dependentSubtrees->count(child) && + hasDependency(child, key)) { + getSubtreesContainingNonMarginalizables( + child, marginalizableKeys, cache, dependentSubtrees); + } + } + } + } + } + + /** + * Gather all subtrees that contain non-marginalizable variables in its root. + */ + static void getSubtreesContainingNonMarginalizables( + const sharedClique& rootClique, + const std::set& marginalizableKeys, + CachedSearch* cache, + std::set* subtreesContainingNonMarginalizables) { + // If the root clique itself contains non-marginalizable variables, we + // just add it to subtreesContainingNonMarginalizables; + if (!isWholeCliqueMarginalizable(rootClique, marginalizableKeys, cache)) { + subtreesContainingNonMarginalizables->insert(rootClique); + return; + } + + // Otherwise, we need to recursively check the children + for (const sharedClique& child : rootClique->children) { + getSubtreesContainingNonMarginalizables( + child, marginalizableKeys, cache, + subtreesContainingNonMarginalizables); + } + } + + /** + * Add all frontal variables from a clique to a key set. + * + * @param[in] clique Clique to add keys from + * @param[out] additionalKeys Pointer to the output key set + */ + static void addCliqueToKeySet( + const sharedClique& clique, + std::set* additionalKeys) { + for (Key key : clique->conditional()->frontals()) { + additionalKeys->insert(key); + } + } + + /** + * Add all frontal variables from a subtree to a key set. + * + * @param[in] subRoot Root clique of the subtree + * @param[out] additionalKeys Pointer to the output key set + */ + static void addSubtreeToKeySet( + const sharedClique& subRoot, + std::set* additionalKeys) { + std::set cliques; + cliques.insert(subRoot); + while(!cliques.empty()) { + auto begin = cliques.begin(); + sharedClique clique = *begin; + cliques.erase(begin); + addCliqueToKeySet(clique, additionalKeys); + for (const sharedClique& child : clique->children) { + cliques.insert(child); + } + } + } + + /** + * Check if the clique depends on the given key. + * + * @param[in] clique Clique to check + * @param[in] key Key to check for dependencies + * @return true if clique depends on the key + */ + static bool hasDependency( + const sharedClique& clique, Key key) { + auto conditional = clique->conditional(); + if (std::find(conditional->beginParents(), + conditional->endParents(), key) + != conditional->endParents()) { + return true; + } else { + return false; + } } }; // BayesTreeMarginalizationHelper diff --git a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp index 9d27f5713..afe4fb3de 100644 --- a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp +++ b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp @@ -120,8 +120,8 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update( } // Mark additional keys between the marginalized keys and the leaves - std::set additionalKeys; #ifdef GTSAM_OLD_MARGINALIZATION + std::set additionalKeys; for(Key key: marginalizableKeys) { ISAM2Clique::shared_ptr clique = isam_[key]; for(const ISAM2Clique::shared_ptr& child: clique->children) { @@ -129,8 +129,9 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update( } } #else - BayesTreeMarginalizationHelper::gatherAdditionalKeysToReEliminate( - isam_, marginalizableKeys, additionalKeys); + std::set additionalKeys = + BayesTreeMarginalizationHelper::gatherAdditionalKeysToReEliminate( + isam_, marginalizableKeys); #endif KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end()); From 14c3467520851a0520b4a4ce8f8eadb806c4e82d Mon Sep 17 00:00:00 2001 From: Jeffrey Date: Thu, 31 Oct 2024 19:05:54 +0800 Subject: [PATCH 5/8] Remove old marginalization code in IncrementalFixedLagSmoother.cpp --- .../nonlinear/IncrementalFixedLagSmoother.cpp | 37 ------------------- 1 file changed, 37 deletions(-) diff --git a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp index afe4fb3de..238ff6b3d 100644 --- a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp +++ b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp @@ -23,34 +23,8 @@ #include #include -// #define GTSAM_OLD_MARGINALIZATION - namespace gtsam { -/* ************************************************************************* */ -#ifdef GTSAM_OLD_MARGINALIZATION -void recursiveMarkAffectedKeys(const Key& key, - const ISAM2Clique::shared_ptr& clique, std::set& additionalKeys) { - - // Check if the separator keys of the current clique contain the specified key - if (std::find(clique->conditional()->beginParents(), - clique->conditional()->endParents(), key) - != clique->conditional()->endParents()) { - - // Mark the frontal keys of the current clique - for(Key i: clique->conditional()->frontals()) { - additionalKeys.insert(i); - } - - // Recursively mark all of the children - for(const ISAM2Clique::shared_ptr& child: clique->children) { - recursiveMarkAffectedKeys(key, child, additionalKeys); - } - } - // If the key was not found in the separator/parents, then none of its children can have it either -} -#endif - /* ************************************************************************* */ void IncrementalFixedLagSmoother::print(const std::string& s, const KeyFormatter& keyFormatter) const { @@ -119,20 +93,9 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update( std::cout << std::endl; } - // Mark additional keys between the marginalized keys and the leaves -#ifdef GTSAM_OLD_MARGINALIZATION - std::set additionalKeys; - for(Key key: marginalizableKeys) { - ISAM2Clique::shared_ptr clique = isam_[key]; - for(const ISAM2Clique::shared_ptr& child: clique->children) { - recursiveMarkAffectedKeys(key, child, additionalKeys); - } - } -#else std::set additionalKeys = BayesTreeMarginalizationHelper::gatherAdditionalKeysToReEliminate( isam_, marginalizableKeys); -#endif KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end()); // Update iSAM2 From 1a5e711f0e49451b51b2e54757f3b894b8e4ee50 Mon Sep 17 00:00:00 2001 From: Jeffrey Date: Thu, 31 Oct 2024 21:52:45 +0800 Subject: [PATCH 6/8] Further optimize the implementation of BayesTreeMarginalizationHelper: Now we won't re-emilinate any unnecessary nodes (we re-emilinated whole subtrees in the previous commits, which is not optimal) --- .../BayesTreeMarginalizationHelper.h | 124 ++++++++---------- 1 file changed, 58 insertions(+), 66 deletions(-) diff --git a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h index 04e4e0cac..724814fd0 100644 --- a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h +++ b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h @@ -21,6 +21,7 @@ #pragma once #include +#include #include #include #include @@ -50,9 +51,12 @@ public: * 2. Or it has a child node depending on a marginalizable variable AND the * subtree rooted at that child contains non-marginalizables. * - * In addition, the subtrees under the aforementioned cliques that require - * re-elimination, which contain non-marginalizable variables in their root - * node, also need to be re-eliminated. + * In addition, for any descendant node depending on a marginalizable + * variable, if the subtree rooted at that descendant contains + * non-marginalizable variables (i.e., it lies on a path from one of the + * aforementioned cliques that require re-elimination to a node containing + * non-marginalizable variables at the leaf side), then it also needs to + * be re-eliminated. * * @param[in] bayesTree The Bayes tree * @param[in] marginalizableKeys Keys to be marginalized @@ -66,7 +70,7 @@ public: std::set additionalKeys; std::set marginalizableKeySet( marginalizableKeys.begin(), marginalizableKeys.end()); - std::set dependentSubtrees; + std::set dependentCliques; CachedSearch cachedSearch; // Check each clique that contains a marginalizable key @@ -77,17 +81,14 @@ public: // Add frontal variables from current clique addCliqueToKeySet(clique, &additionalKeys); - // Then gather dependent subtrees to be added later - gatherDependentSubtrees( - clique, marginalizableKeySet, &dependentSubtrees, &cachedSearch); + // Then add the dependent cliques + for (const sharedClique& dependent : + gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) { + addCliqueToKeySet(dependent, &additionalKeys); + } } } - // Add the remaining dependent cliques - for (const sharedClique& subtree : dependentSubtrees) { - addSubtreeToKeySet(subtree, &additionalKeys); - } - if (debug) { std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; for (const Key& key : additionalKeys) { @@ -219,53 +220,53 @@ public: } /** - * Gather all subtrees that depend on a marginalizable key and contain - * non-marginalizable variables in their root. + * Gather all dependent nodes that lie on a path from the root clique + * to a clique containing a non-marginalizable variable at the leaf side. * - * @param[in] rootClique The starting clique + * @param[in] rootClique The root clique * @param[in] marginalizableKeys Set of keys to be marginalized - * @param[out] dependentSubtrees Pointer to set storing dependent cliques */ - static void gatherDependentSubtrees( + static std::set gatherDependentCliques( const sharedClique& rootClique, const std::set& marginalizableKeys, - std::set* dependentSubtrees, CachedSearch* cache) { - for (Key key : rootClique->conditional()->frontals()) { - if (marginalizableKeys.count(key)) { - // Find children that depend on this key - for (const sharedClique& child : rootClique->children) { - if (!dependentSubtrees->count(child) && - hasDependency(child, key)) { - getSubtreesContainingNonMarginalizables( - child, marginalizableKeys, cache, dependentSubtrees); - } - } + std::vector dependentChildren; + dependentChildren.reserve(rootClique->children.size()); + for (const sharedClique& child : rootClique->children) { + if (hasDependency(child, marginalizableKeys)) { + dependentChildren.push_back(child); } } + return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache); } /** - * Gather all subtrees that contain non-marginalizable variables in its root. + * A helper function for the above gatherDependentCliques(). */ - static void getSubtreesContainingNonMarginalizables( - const sharedClique& rootClique, + static std::set gatherDependentCliquesFromChildren( + const std::vector& dependentChildren, const std::set& marginalizableKeys, - CachedSearch* cache, - std::set* subtreesContainingNonMarginalizables) { - // If the root clique itself contains non-marginalizable variables, we - // just add it to subtreesContainingNonMarginalizables; - if (!isWholeCliqueMarginalizable(rootClique, marginalizableKeys, cache)) { - subtreesContainingNonMarginalizables->insert(rootClique); - return; - } + CachedSearch* cache) { + std::deque descendants( + dependentChildren.begin(), dependentChildren.end()); + std::set dependentCliques; + while (!descendants.empty()) { + sharedClique descendant = descendants.front(); + descendants.pop_front(); - // Otherwise, we need to recursively check the children - for (const sharedClique& child : rootClique->children) { - getSubtreesContainingNonMarginalizables( - child, marginalizableKeys, cache, - subtreesContainingNonMarginalizables); + // If the subtree rooted at this descendant contains non-marginalizables, + // it must lie on a path from the root clique to a clique containing + // non-marginalizables at the leaf side. + if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { + dependentCliques.insert(descendant); + } + + // Add all children of the current descendant to the set descendants. + for (const sharedClique& child : descendant->children) { + descendants.push_back(child); + } } + return dependentCliques; } /** @@ -282,28 +283,6 @@ public: } } - /** - * Add all frontal variables from a subtree to a key set. - * - * @param[in] subRoot Root clique of the subtree - * @param[out] additionalKeys Pointer to the output key set - */ - static void addSubtreeToKeySet( - const sharedClique& subRoot, - std::set* additionalKeys) { - std::set cliques; - cliques.insert(subRoot); - while(!cliques.empty()) { - auto begin = cliques.begin(); - sharedClique clique = *begin; - cliques.erase(begin); - addCliqueToKeySet(clique, additionalKeys); - for (const sharedClique& child : clique->children) { - cliques.insert(child); - } - } - } - /** * Check if the clique depends on the given key. * @@ -322,6 +301,19 @@ public: return false; } } + + /** + * Check if the clique depends on any of the given keys. + */ + static bool hasDependency( + const sharedClique& clique, const std::set& keys) { + for (Key key : keys) { + if (hasDependency(clique, key)) { + return true; + } + } + return false; + } }; // BayesTreeMarginalizationHelper From 0d9c3a99583032e25e36b6135084aaa75159430c Mon Sep 17 00:00:00 2001 From: Jeffrey Date: Thu, 31 Oct 2024 22:23:22 +0800 Subject: [PATCH 7/8] Remove unused variable --- gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h | 1 - 1 file changed, 1 deletion(-) diff --git a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h index 724814fd0..53d030624 100644 --- a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h +++ b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h @@ -70,7 +70,6 @@ public: std::set additionalKeys; std::set marginalizableKeySet( marginalizableKeys.begin(), marginalizableKeys.end()); - std::set dependentCliques; CachedSearch cachedSearch; // Check each clique that contains a marginalizable key From 06dac43cae843e5a2bc16447af9d3bde13e0f54d Mon Sep 17 00:00:00 2001 From: Jeffrey Date: Sat, 2 Nov 2024 17:14:01 +0800 Subject: [PATCH 8/8] Some refinement in BayesTreeMarginalizationHelper: 1. Skip subtrees that have already been visited when searching for dependent cliques; 2. Avoid copying shared_ptrs (which needs extra expensive atomic operations) in the searching. Use const Clique* instead of sharedClique whenever possible; 3. Use std::unordered_set instead of std::set to improve average searching speed. --- .../BayesTreeMarginalizationHelper.h | 167 +++++++++++------- .../nonlinear/IncrementalFixedLagSmoother.cpp | 2 +- 2 files changed, 104 insertions(+), 65 deletions(-) diff --git a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h index 53d030624..1e367e55c 100644 --- a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h +++ b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h @@ -21,6 +21,7 @@ #pragma once #include +#include #include #include #include @@ -62,30 +63,18 @@ public: * @param[in] marginalizableKeys Keys to be marginalized * @return Set of additional keys that need to be re-eliminated */ - static std::set gatherAdditionalKeysToReEliminate( + static std::unordered_set + gatherAdditionalKeysToReEliminate( const BayesTree& bayesTree, const KeyVector& marginalizableKeys) { const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); - std::set additionalKeys; - std::set marginalizableKeySet( - marginalizableKeys.begin(), marginalizableKeys.end()); - CachedSearch cachedSearch; + std::unordered_set additionalCliques = + gatherAdditionalCliquesToReEliminate(bayesTree, marginalizableKeys); - // Check each clique that contains a marginalizable key - for (const sharedClique& clique : - getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { - - if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { - // Add frontal variables from current clique - addCliqueToKeySet(clique, &additionalKeys); - - // Then add the dependent cliques - for (const sharedClique& dependent : - gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) { - addCliqueToKeySet(dependent, &additionalKeys); - } - } + std::unordered_set additionalKeys; + for (const Clique* clique : additionalCliques) { + addCliqueToKeySet(clique, &additionalKeys); } if (debug) { @@ -100,6 +89,41 @@ public: } protected: + /** + * This function identifies cliques that need to be re-eliminated before + * performing marginalization. + * See the docstring of @ref gatherAdditionalKeysToReEliminate(). + */ + static std::unordered_set + gatherAdditionalCliquesToReEliminate( + const BayesTree& bayesTree, + const KeyVector& marginalizableKeys) { + std::unordered_set additionalCliques; + std::unordered_set marginalizableKeySet( + marginalizableKeys.begin(), marginalizableKeys.end()); + CachedSearch cachedSearch; + + // Check each clique that contains a marginalizable key + for (const Clique* clique : + getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { + if (additionalCliques.count(clique)) { + // The clique has already been visited. This can happen when an + // ancestor of the current clique also contain some marginalizable + // varaibles and it's processed beore the current. + continue; + } + + if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { + // Add the current clique + additionalCliques.insert(clique); + + // Then add the dependent cliques + gatherDependentCliques(clique, marginalizableKeySet, &additionalCliques, + &cachedSearch); + } + } + return additionalCliques; + } /** * Gather the cliques containing any of the given keys. @@ -108,12 +132,12 @@ public: * @param[in] keysOfInterest Set of keys of interest * @return Set of cliques that contain any of the given keys */ - static std::set getCliquesContainingKeys( + static std::unordered_set getCliquesContainingKeys( const BayesTree& bayesTree, - const std::set& keysOfInterest) { - std::set cliques; + const std::unordered_set& keysOfInterest) { + std::unordered_set cliques; for (const Key& key : keysOfInterest) { - cliques.insert(bayesTree[key]); + cliques.insert(bayesTree[key].get()); } return cliques; } @@ -122,8 +146,8 @@ public: * A struct to cache the results of the below two functions. */ struct CachedSearch { - std::unordered_map wholeMarginalizableCliques; - std::unordered_map wholeMarginalizableSubtrees; + std::unordered_map wholeMarginalizableCliques; + std::unordered_map wholeMarginalizableSubtrees; }; /** @@ -132,10 +156,10 @@ public: * Note we use a cache map to avoid repeated searches. */ static bool isWholeCliqueMarginalizable( - const sharedClique& clique, - const std::set& marginalizableKeys, + const Clique* clique, + const std::unordered_set& marginalizableKeys, CachedSearch* cache) { - auto it = cache->wholeMarginalizableCliques.find(clique.get()); + auto it = cache->wholeMarginalizableCliques.find(clique); if (it != cache->wholeMarginalizableCliques.end()) { return it->second; } else { @@ -146,7 +170,7 @@ public: break; } } - cache->wholeMarginalizableCliques.insert({clique.get(), ret}); + cache->wholeMarginalizableCliques.insert({clique, ret}); return ret; } } @@ -157,17 +181,17 @@ public: * Note we use a cache map to avoid repeated searches. */ static bool isWholeSubtreeMarginalizable( - const sharedClique& subtree, - const std::set& marginalizableKeys, + const Clique* subtree, + const std::unordered_set& marginalizableKeys, CachedSearch* cache) { - auto it = cache->wholeMarginalizableSubtrees.find(subtree.get()); + auto it = cache->wholeMarginalizableSubtrees.find(subtree); if (it != cache->wholeMarginalizableSubtrees.end()) { return it->second; } else { bool ret = true; if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) { for (const sharedClique& child : subtree->children) { - if (!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { + if (!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { ret = false; break; } @@ -175,7 +199,7 @@ public: } else { ret = false; } - cache->wholeMarginalizableSubtrees.insert({subtree.get(), ret}); + cache->wholeMarginalizableSubtrees.insert({subtree, ret}); return ret; } } @@ -189,8 +213,8 @@ public: * @return true if any variables in the clique need re-elimination */ static bool needsReelimination( - const sharedClique& clique, - const std::set& marginalizableKeys, + const Clique* clique, + const std::unordered_set& marginalizableKeys, CachedSearch* cache) { bool hasNonMarginalizableAhead = false; @@ -206,8 +230,8 @@ public: // Check if any child depends on this marginalizable key and the // subtree rooted at that child contains non-marginalizables. for (const sharedClique& child : clique->children) { - if (hasDependency(child, key) && - !isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { + if (hasDependency(child.get(), key) && + !isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { return true; } } @@ -225,47 +249,59 @@ public: * @param[in] rootClique The root clique * @param[in] marginalizableKeys Set of keys to be marginalized */ - static std::set gatherDependentCliques( - const sharedClique& rootClique, - const std::set& marginalizableKeys, + static void gatherDependentCliques( + const Clique* rootClique, + const std::unordered_set& marginalizableKeys, + std::unordered_set* additionalCliques, CachedSearch* cache) { - std::vector dependentChildren; + std::vector dependentChildren; dependentChildren.reserve(rootClique->children.size()); for (const sharedClique& child : rootClique->children) { - if (hasDependency(child, marginalizableKeys)) { - dependentChildren.push_back(child); + if (additionalCliques->count(child.get())) { + // This child has already been visited. This can happen if the + // child itself contains a marginalizable variable and it's + // processed before the current rootClique. + continue; + } + if (hasDependency(child.get(), marginalizableKeys)) { + dependentChildren.push_back(child.get()); } } - return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache); + gatherDependentCliquesFromChildren( + dependentChildren, marginalizableKeys, additionalCliques, cache); } /** * A helper function for the above gatherDependentCliques(). */ - static std::set gatherDependentCliquesFromChildren( - const std::vector& dependentChildren, - const std::set& marginalizableKeys, + static void gatherDependentCliquesFromChildren( + const std::vector& dependentChildren, + const std::unordered_set& marginalizableKeys, + std::unordered_set* additionalCliques, CachedSearch* cache) { - std::deque descendants( + std::deque descendants( dependentChildren.begin(), dependentChildren.end()); - std::set dependentCliques; while (!descendants.empty()) { - sharedClique descendant = descendants.front(); + const Clique* descendant = descendants.front(); descendants.pop_front(); // If the subtree rooted at this descendant contains non-marginalizables, // it must lie on a path from the root clique to a clique containing // non-marginalizables at the leaf side. if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { - dependentCliques.insert(descendant); - } + additionalCliques->insert(descendant); - // Add all children of the current descendant to the set descendants. - for (const sharedClique& child : descendant->children) { - descendants.push_back(child); + // Add children of the current descendant to the set descendants. + for (const sharedClique& child : descendant->children) { + if (additionalCliques->count(child.get())) { + // This child has already been visited. + continue; + } else { + descendants.push_back(child.get()); + } + } } } - return dependentCliques; } /** @@ -275,8 +311,8 @@ public: * @param[out] additionalKeys Pointer to the output key set */ static void addCliqueToKeySet( - const sharedClique& clique, - std::set* additionalKeys) { + const Clique* clique, + std::unordered_set* additionalKeys) { for (Key key : clique->conditional()->frontals()) { additionalKeys->insert(key); } @@ -290,8 +326,8 @@ public: * @return true if clique depends on the key */ static bool hasDependency( - const sharedClique& clique, Key key) { - auto conditional = clique->conditional(); + const Clique* clique, Key key) { + auto& conditional = clique->conditional(); if (std::find(conditional->beginParents(), conditional->endParents(), key) != conditional->endParents()) { @@ -305,12 +341,15 @@ public: * Check if the clique depends on any of the given keys. */ static bool hasDependency( - const sharedClique& clique, const std::set& keys) { - for (Key key : keys) { - if (hasDependency(clique, key)) { + const Clique* clique, const std::unordered_set& keys) { + auto& conditional = clique->conditional(); + for (auto it = conditional->beginParents(); + it != conditional->endParents(); ++it) { + if (keys.count(*it)) { return true; } } + return false; } }; diff --git a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp index 238ff6b3d..0cd9ecbac 100644 --- a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp +++ b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp @@ -93,7 +93,7 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update( std::cout << std::endl; } - std::set additionalKeys = + std::unordered_set additionalKeys = BayesTreeMarginalizationHelper::gatherAdditionalKeysToReEliminate( isam_, marginalizableKeys); KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end());