diff --git a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h index 04e4e0cac..724814fd0 100644 --- a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h +++ b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h @@ -21,6 +21,7 @@ #pragma once #include +#include #include #include #include @@ -50,9 +51,12 @@ public: * 2. Or it has a child node depending on a marginalizable variable AND the * subtree rooted at that child contains non-marginalizables. * - * In addition, the subtrees under the aforementioned cliques that require - * re-elimination, which contain non-marginalizable variables in their root - * node, also need to be re-eliminated. + * In addition, for any descendant node depending on a marginalizable + * variable, if the subtree rooted at that descendant contains + * non-marginalizable variables (i.e., it lies on a path from one of the + * aforementioned cliques that require re-elimination to a node containing + * non-marginalizable variables at the leaf side), then it also needs to + * be re-eliminated. * * @param[in] bayesTree The Bayes tree * @param[in] marginalizableKeys Keys to be marginalized @@ -66,7 +70,7 @@ public: std::set additionalKeys; std::set marginalizableKeySet( marginalizableKeys.begin(), marginalizableKeys.end()); - std::set dependentSubtrees; + std::set dependentCliques; CachedSearch cachedSearch; // Check each clique that contains a marginalizable key @@ -77,17 +81,14 @@ public: // Add frontal variables from current clique addCliqueToKeySet(clique, &additionalKeys); - // Then gather dependent subtrees to be added later - gatherDependentSubtrees( - clique, marginalizableKeySet, &dependentSubtrees, &cachedSearch); + // Then add the dependent cliques + for (const sharedClique& dependent : + gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) { + addCliqueToKeySet(dependent, &additionalKeys); + } } } - // Add the remaining dependent cliques - for (const sharedClique& subtree : dependentSubtrees) { - addSubtreeToKeySet(subtree, &additionalKeys); - } - if (debug) { std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; for (const Key& key : additionalKeys) { @@ -219,53 +220,53 @@ public: } /** - * Gather all subtrees that depend on a marginalizable key and contain - * non-marginalizable variables in their root. + * Gather all dependent nodes that lie on a path from the root clique + * to a clique containing a non-marginalizable variable at the leaf side. * - * @param[in] rootClique The starting clique + * @param[in] rootClique The root clique * @param[in] marginalizableKeys Set of keys to be marginalized - * @param[out] dependentSubtrees Pointer to set storing dependent cliques */ - static void gatherDependentSubtrees( + static std::set gatherDependentCliques( const sharedClique& rootClique, const std::set& marginalizableKeys, - std::set* dependentSubtrees, CachedSearch* cache) { - for (Key key : rootClique->conditional()->frontals()) { - if (marginalizableKeys.count(key)) { - // Find children that depend on this key - for (const sharedClique& child : rootClique->children) { - if (!dependentSubtrees->count(child) && - hasDependency(child, key)) { - getSubtreesContainingNonMarginalizables( - child, marginalizableKeys, cache, dependentSubtrees); - } - } + std::vector dependentChildren; + dependentChildren.reserve(rootClique->children.size()); + for (const sharedClique& child : rootClique->children) { + if (hasDependency(child, marginalizableKeys)) { + dependentChildren.push_back(child); } } + return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache); } /** - * Gather all subtrees that contain non-marginalizable variables in its root. + * A helper function for the above gatherDependentCliques(). */ - static void getSubtreesContainingNonMarginalizables( - const sharedClique& rootClique, + static std::set gatherDependentCliquesFromChildren( + const std::vector& dependentChildren, const std::set& marginalizableKeys, - CachedSearch* cache, - std::set* subtreesContainingNonMarginalizables) { - // If the root clique itself contains non-marginalizable variables, we - // just add it to subtreesContainingNonMarginalizables; - if (!isWholeCliqueMarginalizable(rootClique, marginalizableKeys, cache)) { - subtreesContainingNonMarginalizables->insert(rootClique); - return; - } + CachedSearch* cache) { + std::deque descendants( + dependentChildren.begin(), dependentChildren.end()); + std::set dependentCliques; + while (!descendants.empty()) { + sharedClique descendant = descendants.front(); + descendants.pop_front(); - // Otherwise, we need to recursively check the children - for (const sharedClique& child : rootClique->children) { - getSubtreesContainingNonMarginalizables( - child, marginalizableKeys, cache, - subtreesContainingNonMarginalizables); + // If the subtree rooted at this descendant contains non-marginalizables, + // it must lie on a path from the root clique to a clique containing + // non-marginalizables at the leaf side. + if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { + dependentCliques.insert(descendant); + } + + // Add all children of the current descendant to the set descendants. + for (const sharedClique& child : descendant->children) { + descendants.push_back(child); + } } + return dependentCliques; } /** @@ -282,28 +283,6 @@ public: } } - /** - * Add all frontal variables from a subtree to a key set. - * - * @param[in] subRoot Root clique of the subtree - * @param[out] additionalKeys Pointer to the output key set - */ - static void addSubtreeToKeySet( - const sharedClique& subRoot, - std::set* additionalKeys) { - std::set cliques; - cliques.insert(subRoot); - while(!cliques.empty()) { - auto begin = cliques.begin(); - sharedClique clique = *begin; - cliques.erase(begin); - addCliqueToKeySet(clique, additionalKeys); - for (const sharedClique& child : clique->children) { - cliques.insert(child); - } - } - } - /** * Check if the clique depends on the given key. * @@ -322,6 +301,19 @@ public: return false; } } + + /** + * Check if the clique depends on any of the given keys. + */ + static bool hasDependency( + const sharedClique& clique, const std::set& keys) { + for (Key key : keys) { + if (hasDependency(clique, key)) { + return true; + } + } + return false; + } }; // BayesTreeMarginalizationHelper