From a832a52f375aead9c72ac0fc56f4de89b7c679cd Mon Sep 17 00:00:00 2001 From: p-zach Date: Sun, 6 Apr 2025 15:35:04 -0400 Subject: [PATCH] Move bayes tree marg helper (no .i presence) --- .../BayesTreeMarginalizationHelper.h | 358 ++++++++++++++++++ .../nonlinear/IncrementalFixedLagSmoother.cpp | 2 +- .../BayesTreeMarginalizationHelper.h | 349 +---------------- 3 files changed, 365 insertions(+), 344 deletions(-) create mode 100644 gtsam/nonlinear/BayesTreeMarginalizationHelper.h diff --git a/gtsam/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam/nonlinear/BayesTreeMarginalizationHelper.h new file mode 100644 index 000000000..133de6c46 --- /dev/null +++ b/gtsam/nonlinear/BayesTreeMarginalizationHelper.h @@ -0,0 +1,358 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file BayesTreeMarginalizationHelper.h + * @brief Helper functions for marginalizing variables from a Bayes Tree. + * + * @author Jeffrey (Zhiwei Wang) + * @date Oct 28, 2024 + */ + +// \callgraph +#pragma once + +#include +#include +#include +#include +#include +#include +#include "gtsam/dllexport.h" + +namespace gtsam { + +/** + * This class provides helper functions for marginalizing variables from a Bayes Tree. + */ +template +class GTSAM_EXPORT BayesTreeMarginalizationHelper { + +public: + using Clique = typename BayesTree::Clique; + using sharedClique = typename BayesTree::sharedClique; + + /** + * This function identifies variables that need to be re-eliminated before + * performing marginalization. + * + * Re-elimination is necessary for a clique containing marginalizable + * variables if: + * + * 1. Some non-marginalizable variables appear before marginalizable ones + * in that clique; + * 2. Or it has a child node depending on a marginalizable variable AND the + * subtree rooted at that child contains non-marginalizables. + * + * In addition, for any descendant node depending on a marginalizable + * variable, if the subtree rooted at that descendant contains + * non-marginalizable variables (i.e., it lies on a path from one of the + * aforementioned cliques that require re-elimination to a node containing + * non-marginalizable variables at the leaf side), then it also needs to + * be re-eliminated. + * + * @param[in] bayesTree The Bayes tree + * @param[in] marginalizableKeys Keys to be marginalized + * @return Set of additional keys that need to be re-eliminated + */ + static std::unordered_set + gatherAdditionalKeysToReEliminate( + const BayesTree& bayesTree, + const KeyVector& marginalizableKeys) { + const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); + + std::unordered_set additionalCliques = + gatherAdditionalCliquesToReEliminate(bayesTree, marginalizableKeys); + + std::unordered_set additionalKeys; + for (const Clique* clique : additionalCliques) { + addCliqueToKeySet(clique, &additionalKeys); + } + + if (debug) { + std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; + for (const Key& key : additionalKeys) { + std::cout << DefaultKeyFormatter(key) << " "; + } + std::cout << std::endl; + } + + return additionalKeys; + } + + protected: + /** + * This function identifies cliques that need to be re-eliminated before + * performing marginalization. + * See the docstring of @ref gatherAdditionalKeysToReEliminate(). + */ + static std::unordered_set + gatherAdditionalCliquesToReEliminate( + const BayesTree& bayesTree, + const KeyVector& marginalizableKeys) { + std::unordered_set additionalCliques; + std::unordered_set marginalizableKeySet( + marginalizableKeys.begin(), marginalizableKeys.end()); + CachedSearch cachedSearch; + + // Check each clique that contains a marginalizable key + for (const Clique* clique : + getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { + if (additionalCliques.count(clique)) { + // The clique has already been visited. This can happen when an + // ancestor of the current clique also contain some marginalizable + // varaibles and it's processed beore the current. + continue; + } + + if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { + // Add the current clique + additionalCliques.insert(clique); + + // Then add the dependent cliques + gatherDependentCliques(clique, marginalizableKeySet, &additionalCliques, + &cachedSearch); + } + } + return additionalCliques; + } + + /** + * Gather the cliques containing any of the given keys. + * + * @param[in] bayesTree The Bayes tree + * @param[in] keysOfInterest Set of keys of interest + * @return Set of cliques that contain any of the given keys + */ + static std::unordered_set getCliquesContainingKeys( + const BayesTree& bayesTree, + const std::unordered_set& keysOfInterest) { + std::unordered_set cliques; + for (const Key& key : keysOfInterest) { + cliques.insert(bayesTree[key].get()); + } + return cliques; + } + + /** + * A struct to cache the results of the below two functions. + */ + struct CachedSearch { + std::unordered_map wholeMarginalizableCliques; + std::unordered_map wholeMarginalizableSubtrees; + }; + + /** + * Check if all variables in the clique are marginalizable. + * + * Note we use a cache map to avoid repeated searches. + */ + static bool isWholeCliqueMarginalizable( + const Clique* clique, + const std::unordered_set& marginalizableKeys, + CachedSearch* cache) { + auto it = cache->wholeMarginalizableCliques.find(clique); + if (it != cache->wholeMarginalizableCliques.end()) { + return it->second; + } else { + bool ret = true; + for (Key key : clique->conditional()->frontals()) { + if (!marginalizableKeys.count(key)) { + ret = false; + break; + } + } + cache->wholeMarginalizableCliques.insert({clique, ret}); + return ret; + } + } + + /** + * Check if all variables in the subtree are marginalizable. + * + * Note we use a cache map to avoid repeated searches. + */ + static bool isWholeSubtreeMarginalizable( + const Clique* subtree, + const std::unordered_set& marginalizableKeys, + CachedSearch* cache) { + auto it = cache->wholeMarginalizableSubtrees.find(subtree); + if (it != cache->wholeMarginalizableSubtrees.end()) { + return it->second; + } else { + bool ret = true; + if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) { + for (const sharedClique& child : subtree->children) { + if (!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { + ret = false; + break; + } + } + } else { + ret = false; + } + cache->wholeMarginalizableSubtrees.insert({subtree, ret}); + return ret; + } + } + + /** + * Check if a clique contains variables that need reelimination due to + * elimination ordering conflicts. + * + * @param[in] clique The clique to check + * @param[in] marginalizableKeys Set of keys to be marginalized + * @return true if any variables in the clique need re-elimination + */ + static bool needsReelimination( + const Clique* clique, + const std::unordered_set& marginalizableKeys, + CachedSearch* cache) { + bool hasNonMarginalizableAhead = false; + + // Check each frontal variable in order + for (Key key : clique->conditional()->frontals()) { + if (marginalizableKeys.count(key)) { + // If we've seen non-marginalizable variables before this one, + // we need to reeliminate + if (hasNonMarginalizableAhead) { + return true; + } + + // Check if any child depends on this marginalizable key and the + // subtree rooted at that child contains non-marginalizables. + for (const sharedClique& child : clique->children) { + if (hasDependency(child.get(), key) && + !isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { + return true; + } + } + } else { + hasNonMarginalizableAhead = true; + } + } + return false; + } + + /** + * Gather all dependent nodes that lie on a path from the root clique + * to a clique containing a non-marginalizable variable at the leaf side. + * + * @param[in] rootClique The root clique + * @param[in] marginalizableKeys Set of keys to be marginalized + */ + static void gatherDependentCliques( + const Clique* rootClique, + const std::unordered_set& marginalizableKeys, + std::unordered_set* additionalCliques, + CachedSearch* cache) { + std::vector dependentChildren; + dependentChildren.reserve(rootClique->children.size()); + for (const sharedClique& child : rootClique->children) { + if (additionalCliques->count(child.get())) { + // This child has already been visited. This can happen if the + // child itself contains a marginalizable variable and it's + // processed before the current rootClique. + continue; + } + if (hasDependency(child.get(), marginalizableKeys)) { + dependentChildren.push_back(child.get()); + } + } + gatherDependentCliquesFromChildren( + dependentChildren, marginalizableKeys, additionalCliques, cache); + } + + /** + * A helper function for the above gatherDependentCliques(). + */ + static void gatherDependentCliquesFromChildren( + const std::vector& dependentChildren, + const std::unordered_set& marginalizableKeys, + std::unordered_set* additionalCliques, + CachedSearch* cache) { + std::deque descendants( + dependentChildren.begin(), dependentChildren.end()); + while (!descendants.empty()) { + const Clique* descendant = descendants.front(); + descendants.pop_front(); + + // If the subtree rooted at this descendant contains non-marginalizables, + // it must lie on a path from the root clique to a clique containing + // non-marginalizables at the leaf side. + if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { + additionalCliques->insert(descendant); + + // Add children of the current descendant to the set descendants. + for (const sharedClique& child : descendant->children) { + if (additionalCliques->count(child.get())) { + // This child has already been visited. + continue; + } else { + descendants.push_back(child.get()); + } + } + } + } + } + + /** + * Add all frontal variables from a clique to a key set. + * + * @param[in] clique Clique to add keys from + * @param[out] additionalKeys Pointer to the output key set + */ + static void addCliqueToKeySet( + const Clique* clique, + std::unordered_set* additionalKeys) { + for (Key key : clique->conditional()->frontals()) { + additionalKeys->insert(key); + } + } + + /** + * Check if the clique depends on the given key. + * + * @param[in] clique Clique to check + * @param[in] key Key to check for dependencies + * @return true if clique depends on the key + */ + static bool hasDependency( + const Clique* clique, Key key) { + auto& conditional = clique->conditional(); + if (std::find(conditional->beginParents(), + conditional->endParents(), key) + != conditional->endParents()) { + return true; + } else { + return false; + } + } + + /** + * Check if the clique depends on any of the given keys. + */ + static bool hasDependency( + const Clique* clique, const std::unordered_set& keys) { + auto& conditional = clique->conditional(); + for (auto it = conditional->beginParents(); + it != conditional->endParents(); ++it) { + if (keys.count(*it)) { + return true; + } + } + + return false; + } +}; +// BayesTreeMarginalizationHelper + +}/// namespace gtsam diff --git a/gtsam/nonlinear/IncrementalFixedLagSmoother.cpp b/gtsam/nonlinear/IncrementalFixedLagSmoother.cpp index ed6f2dba3..394eb77b1 100644 --- a/gtsam/nonlinear/IncrementalFixedLagSmoother.cpp +++ b/gtsam/nonlinear/IncrementalFixedLagSmoother.cpp @@ -20,7 +20,7 @@ */ #include -#include +#include #include namespace gtsam { diff --git a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h index 1e367e55c..6e7d699b4 100644 --- a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h +++ b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h @@ -9,350 +9,13 @@ * -------------------------------------------------------------------------- */ -/** - * @file BayesTreeMarginalizationHelper.h - * @brief Helper functions for marginalizing variables from a Bayes Tree. - * - * @author Jeffrey (Zhiwei Wang) - * @date Oct 28, 2024 - */ - -// \callgraph #pragma once -#include -#include -#include -#include -#include -#include -#include "gtsam_unstable/dllexport.h" +#ifdef _MSC_VER +#pragma message("BayesTreeMarginalizationHelper was moved to the gtsam/nonlinear directory") +#else +#warning "BayesTreeMarginalizationHelper was moved to the gtsam/nonlinear directory" +#endif -namespace gtsam { -/** - * This class provides helper functions for marginalizing variables from a Bayes Tree. - */ -template -class GTSAM_UNSTABLE_EXPORT BayesTreeMarginalizationHelper { - -public: - using Clique = typename BayesTree::Clique; - using sharedClique = typename BayesTree::sharedClique; - - /** - * This function identifies variables that need to be re-eliminated before - * performing marginalization. - * - * Re-elimination is necessary for a clique containing marginalizable - * variables if: - * - * 1. Some non-marginalizable variables appear before marginalizable ones - * in that clique; - * 2. Or it has a child node depending on a marginalizable variable AND the - * subtree rooted at that child contains non-marginalizables. - * - * In addition, for any descendant node depending on a marginalizable - * variable, if the subtree rooted at that descendant contains - * non-marginalizable variables (i.e., it lies on a path from one of the - * aforementioned cliques that require re-elimination to a node containing - * non-marginalizable variables at the leaf side), then it also needs to - * be re-eliminated. - * - * @param[in] bayesTree The Bayes tree - * @param[in] marginalizableKeys Keys to be marginalized - * @return Set of additional keys that need to be re-eliminated - */ - static std::unordered_set - gatherAdditionalKeysToReEliminate( - const BayesTree& bayesTree, - const KeyVector& marginalizableKeys) { - const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); - - std::unordered_set additionalCliques = - gatherAdditionalCliquesToReEliminate(bayesTree, marginalizableKeys); - - std::unordered_set additionalKeys; - for (const Clique* clique : additionalCliques) { - addCliqueToKeySet(clique, &additionalKeys); - } - - if (debug) { - std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; - for (const Key& key : additionalKeys) { - std::cout << DefaultKeyFormatter(key) << " "; - } - std::cout << std::endl; - } - - return additionalKeys; - } - - protected: - /** - * This function identifies cliques that need to be re-eliminated before - * performing marginalization. - * See the docstring of @ref gatherAdditionalKeysToReEliminate(). - */ - static std::unordered_set - gatherAdditionalCliquesToReEliminate( - const BayesTree& bayesTree, - const KeyVector& marginalizableKeys) { - std::unordered_set additionalCliques; - std::unordered_set marginalizableKeySet( - marginalizableKeys.begin(), marginalizableKeys.end()); - CachedSearch cachedSearch; - - // Check each clique that contains a marginalizable key - for (const Clique* clique : - getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { - if (additionalCliques.count(clique)) { - // The clique has already been visited. This can happen when an - // ancestor of the current clique also contain some marginalizable - // varaibles and it's processed beore the current. - continue; - } - - if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { - // Add the current clique - additionalCliques.insert(clique); - - // Then add the dependent cliques - gatherDependentCliques(clique, marginalizableKeySet, &additionalCliques, - &cachedSearch); - } - } - return additionalCliques; - } - - /** - * Gather the cliques containing any of the given keys. - * - * @param[in] bayesTree The Bayes tree - * @param[in] keysOfInterest Set of keys of interest - * @return Set of cliques that contain any of the given keys - */ - static std::unordered_set getCliquesContainingKeys( - const BayesTree& bayesTree, - const std::unordered_set& keysOfInterest) { - std::unordered_set cliques; - for (const Key& key : keysOfInterest) { - cliques.insert(bayesTree[key].get()); - } - return cliques; - } - - /** - * A struct to cache the results of the below two functions. - */ - struct CachedSearch { - std::unordered_map wholeMarginalizableCliques; - std::unordered_map wholeMarginalizableSubtrees; - }; - - /** - * Check if all variables in the clique are marginalizable. - * - * Note we use a cache map to avoid repeated searches. - */ - static bool isWholeCliqueMarginalizable( - const Clique* clique, - const std::unordered_set& marginalizableKeys, - CachedSearch* cache) { - auto it = cache->wholeMarginalizableCliques.find(clique); - if (it != cache->wholeMarginalizableCliques.end()) { - return it->second; - } else { - bool ret = true; - for (Key key : clique->conditional()->frontals()) { - if (!marginalizableKeys.count(key)) { - ret = false; - break; - } - } - cache->wholeMarginalizableCliques.insert({clique, ret}); - return ret; - } - } - - /** - * Check if all variables in the subtree are marginalizable. - * - * Note we use a cache map to avoid repeated searches. - */ - static bool isWholeSubtreeMarginalizable( - const Clique* subtree, - const std::unordered_set& marginalizableKeys, - CachedSearch* cache) { - auto it = cache->wholeMarginalizableSubtrees.find(subtree); - if (it != cache->wholeMarginalizableSubtrees.end()) { - return it->second; - } else { - bool ret = true; - if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) { - for (const sharedClique& child : subtree->children) { - if (!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { - ret = false; - break; - } - } - } else { - ret = false; - } - cache->wholeMarginalizableSubtrees.insert({subtree, ret}); - return ret; - } - } - - /** - * Check if a clique contains variables that need reelimination due to - * elimination ordering conflicts. - * - * @param[in] clique The clique to check - * @param[in] marginalizableKeys Set of keys to be marginalized - * @return true if any variables in the clique need re-elimination - */ - static bool needsReelimination( - const Clique* clique, - const std::unordered_set& marginalizableKeys, - CachedSearch* cache) { - bool hasNonMarginalizableAhead = false; - - // Check each frontal variable in order - for (Key key : clique->conditional()->frontals()) { - if (marginalizableKeys.count(key)) { - // If we've seen non-marginalizable variables before this one, - // we need to reeliminate - if (hasNonMarginalizableAhead) { - return true; - } - - // Check if any child depends on this marginalizable key and the - // subtree rooted at that child contains non-marginalizables. - for (const sharedClique& child : clique->children) { - if (hasDependency(child.get(), key) && - !isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { - return true; - } - } - } else { - hasNonMarginalizableAhead = true; - } - } - return false; - } - - /** - * Gather all dependent nodes that lie on a path from the root clique - * to a clique containing a non-marginalizable variable at the leaf side. - * - * @param[in] rootClique The root clique - * @param[in] marginalizableKeys Set of keys to be marginalized - */ - static void gatherDependentCliques( - const Clique* rootClique, - const std::unordered_set& marginalizableKeys, - std::unordered_set* additionalCliques, - CachedSearch* cache) { - std::vector dependentChildren; - dependentChildren.reserve(rootClique->children.size()); - for (const sharedClique& child : rootClique->children) { - if (additionalCliques->count(child.get())) { - // This child has already been visited. This can happen if the - // child itself contains a marginalizable variable and it's - // processed before the current rootClique. - continue; - } - if (hasDependency(child.get(), marginalizableKeys)) { - dependentChildren.push_back(child.get()); - } - } - gatherDependentCliquesFromChildren( - dependentChildren, marginalizableKeys, additionalCliques, cache); - } - - /** - * A helper function for the above gatherDependentCliques(). - */ - static void gatherDependentCliquesFromChildren( - const std::vector& dependentChildren, - const std::unordered_set& marginalizableKeys, - std::unordered_set* additionalCliques, - CachedSearch* cache) { - std::deque descendants( - dependentChildren.begin(), dependentChildren.end()); - while (!descendants.empty()) { - const Clique* descendant = descendants.front(); - descendants.pop_front(); - - // If the subtree rooted at this descendant contains non-marginalizables, - // it must lie on a path from the root clique to a clique containing - // non-marginalizables at the leaf side. - if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { - additionalCliques->insert(descendant); - - // Add children of the current descendant to the set descendants. - for (const sharedClique& child : descendant->children) { - if (additionalCliques->count(child.get())) { - // This child has already been visited. - continue; - } else { - descendants.push_back(child.get()); - } - } - } - } - } - - /** - * Add all frontal variables from a clique to a key set. - * - * @param[in] clique Clique to add keys from - * @param[out] additionalKeys Pointer to the output key set - */ - static void addCliqueToKeySet( - const Clique* clique, - std::unordered_set* additionalKeys) { - for (Key key : clique->conditional()->frontals()) { - additionalKeys->insert(key); - } - } - - /** - * Check if the clique depends on the given key. - * - * @param[in] clique Clique to check - * @param[in] key Key to check for dependencies - * @return true if clique depends on the key - */ - static bool hasDependency( - const Clique* clique, Key key) { - auto& conditional = clique->conditional(); - if (std::find(conditional->beginParents(), - conditional->endParents(), key) - != conditional->endParents()) { - return true; - } else { - return false; - } - } - - /** - * Check if the clique depends on any of the given keys. - */ - static bool hasDependency( - const Clique* clique, const std::unordered_set& keys) { - auto& conditional = clique->conditional(); - for (auto it = conditional->beginParents(); - it != conditional->endParents(); ++it) { - if (keys.count(*it)) { - return true; - } - } - - return false; - } -}; -// BayesTreeMarginalizationHelper - -}/// namespace gtsam +#include \ No newline at end of file