diff --git a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h index 53d030624..1e367e55c 100644 --- a/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h +++ b/gtsam_unstable/nonlinear/BayesTreeMarginalizationHelper.h @@ -21,6 +21,7 @@ #pragma once #include +#include #include #include #include @@ -62,30 +63,18 @@ public: * @param[in] marginalizableKeys Keys to be marginalized * @return Set of additional keys that need to be re-eliminated */ - static std::set gatherAdditionalKeysToReEliminate( + static std::unordered_set + gatherAdditionalKeysToReEliminate( const BayesTree& bayesTree, const KeyVector& marginalizableKeys) { const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); - std::set additionalKeys; - std::set marginalizableKeySet( - marginalizableKeys.begin(), marginalizableKeys.end()); - CachedSearch cachedSearch; + std::unordered_set additionalCliques = + gatherAdditionalCliquesToReEliminate(bayesTree, marginalizableKeys); - // Check each clique that contains a marginalizable key - for (const sharedClique& clique : - getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { - - if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { - // Add frontal variables from current clique - addCliqueToKeySet(clique, &additionalKeys); - - // Then add the dependent cliques - for (const sharedClique& dependent : - gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) { - addCliqueToKeySet(dependent, &additionalKeys); - } - } + std::unordered_set additionalKeys; + for (const Clique* clique : additionalCliques) { + addCliqueToKeySet(clique, &additionalKeys); } if (debug) { @@ -100,6 +89,41 @@ public: } protected: + /** + * This function identifies cliques that need to be re-eliminated before + * performing marginalization. + * See the docstring of @ref gatherAdditionalKeysToReEliminate(). + */ + static std::unordered_set + gatherAdditionalCliquesToReEliminate( + const BayesTree& bayesTree, + const KeyVector& marginalizableKeys) { + std::unordered_set additionalCliques; + std::unordered_set marginalizableKeySet( + marginalizableKeys.begin(), marginalizableKeys.end()); + CachedSearch cachedSearch; + + // Check each clique that contains a marginalizable key + for (const Clique* clique : + getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { + if (additionalCliques.count(clique)) { + // The clique has already been visited. This can happen when an + // ancestor of the current clique also contain some marginalizable + // varaibles and it's processed beore the current. + continue; + } + + if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { + // Add the current clique + additionalCliques.insert(clique); + + // Then add the dependent cliques + gatherDependentCliques(clique, marginalizableKeySet, &additionalCliques, + &cachedSearch); + } + } + return additionalCliques; + } /** * Gather the cliques containing any of the given keys. @@ -108,12 +132,12 @@ public: * @param[in] keysOfInterest Set of keys of interest * @return Set of cliques that contain any of the given keys */ - static std::set getCliquesContainingKeys( + static std::unordered_set getCliquesContainingKeys( const BayesTree& bayesTree, - const std::set& keysOfInterest) { - std::set cliques; + const std::unordered_set& keysOfInterest) { + std::unordered_set cliques; for (const Key& key : keysOfInterest) { - cliques.insert(bayesTree[key]); + cliques.insert(bayesTree[key].get()); } return cliques; } @@ -122,8 +146,8 @@ public: * A struct to cache the results of the below two functions. */ struct CachedSearch { - std::unordered_map wholeMarginalizableCliques; - std::unordered_map wholeMarginalizableSubtrees; + std::unordered_map wholeMarginalizableCliques; + std::unordered_map wholeMarginalizableSubtrees; }; /** @@ -132,10 +156,10 @@ public: * Note we use a cache map to avoid repeated searches. */ static bool isWholeCliqueMarginalizable( - const sharedClique& clique, - const std::set& marginalizableKeys, + const Clique* clique, + const std::unordered_set& marginalizableKeys, CachedSearch* cache) { - auto it = cache->wholeMarginalizableCliques.find(clique.get()); + auto it = cache->wholeMarginalizableCliques.find(clique); if (it != cache->wholeMarginalizableCliques.end()) { return it->second; } else { @@ -146,7 +170,7 @@ public: break; } } - cache->wholeMarginalizableCliques.insert({clique.get(), ret}); + cache->wholeMarginalizableCliques.insert({clique, ret}); return ret; } } @@ -157,17 +181,17 @@ public: * Note we use a cache map to avoid repeated searches. */ static bool isWholeSubtreeMarginalizable( - const sharedClique& subtree, - const std::set& marginalizableKeys, + const Clique* subtree, + const std::unordered_set& marginalizableKeys, CachedSearch* cache) { - auto it = cache->wholeMarginalizableSubtrees.find(subtree.get()); + auto it = cache->wholeMarginalizableSubtrees.find(subtree); if (it != cache->wholeMarginalizableSubtrees.end()) { return it->second; } else { bool ret = true; if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) { for (const sharedClique& child : subtree->children) { - if (!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { + if (!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { ret = false; break; } @@ -175,7 +199,7 @@ public: } else { ret = false; } - cache->wholeMarginalizableSubtrees.insert({subtree.get(), ret}); + cache->wholeMarginalizableSubtrees.insert({subtree, ret}); return ret; } } @@ -189,8 +213,8 @@ public: * @return true if any variables in the clique need re-elimination */ static bool needsReelimination( - const sharedClique& clique, - const std::set& marginalizableKeys, + const Clique* clique, + const std::unordered_set& marginalizableKeys, CachedSearch* cache) { bool hasNonMarginalizableAhead = false; @@ -206,8 +230,8 @@ public: // Check if any child depends on this marginalizable key and the // subtree rooted at that child contains non-marginalizables. for (const sharedClique& child : clique->children) { - if (hasDependency(child, key) && - !isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { + if (hasDependency(child.get(), key) && + !isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { return true; } } @@ -225,47 +249,59 @@ public: * @param[in] rootClique The root clique * @param[in] marginalizableKeys Set of keys to be marginalized */ - static std::set gatherDependentCliques( - const sharedClique& rootClique, - const std::set& marginalizableKeys, + static void gatherDependentCliques( + const Clique* rootClique, + const std::unordered_set& marginalizableKeys, + std::unordered_set* additionalCliques, CachedSearch* cache) { - std::vector dependentChildren; + std::vector dependentChildren; dependentChildren.reserve(rootClique->children.size()); for (const sharedClique& child : rootClique->children) { - if (hasDependency(child, marginalizableKeys)) { - dependentChildren.push_back(child); + if (additionalCliques->count(child.get())) { + // This child has already been visited. This can happen if the + // child itself contains a marginalizable variable and it's + // processed before the current rootClique. + continue; + } + if (hasDependency(child.get(), marginalizableKeys)) { + dependentChildren.push_back(child.get()); } } - return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache); + gatherDependentCliquesFromChildren( + dependentChildren, marginalizableKeys, additionalCliques, cache); } /** * A helper function for the above gatherDependentCliques(). */ - static std::set gatherDependentCliquesFromChildren( - const std::vector& dependentChildren, - const std::set& marginalizableKeys, + static void gatherDependentCliquesFromChildren( + const std::vector& dependentChildren, + const std::unordered_set& marginalizableKeys, + std::unordered_set* additionalCliques, CachedSearch* cache) { - std::deque descendants( + std::deque descendants( dependentChildren.begin(), dependentChildren.end()); - std::set dependentCliques; while (!descendants.empty()) { - sharedClique descendant = descendants.front(); + const Clique* descendant = descendants.front(); descendants.pop_front(); // If the subtree rooted at this descendant contains non-marginalizables, // it must lie on a path from the root clique to a clique containing // non-marginalizables at the leaf side. if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { - dependentCliques.insert(descendant); - } + additionalCliques->insert(descendant); - // Add all children of the current descendant to the set descendants. - for (const sharedClique& child : descendant->children) { - descendants.push_back(child); + // Add children of the current descendant to the set descendants. + for (const sharedClique& child : descendant->children) { + if (additionalCliques->count(child.get())) { + // This child has already been visited. + continue; + } else { + descendants.push_back(child.get()); + } + } } } - return dependentCliques; } /** @@ -275,8 +311,8 @@ public: * @param[out] additionalKeys Pointer to the output key set */ static void addCliqueToKeySet( - const sharedClique& clique, - std::set* additionalKeys) { + const Clique* clique, + std::unordered_set* additionalKeys) { for (Key key : clique->conditional()->frontals()) { additionalKeys->insert(key); } @@ -290,8 +326,8 @@ public: * @return true if clique depends on the key */ static bool hasDependency( - const sharedClique& clique, Key key) { - auto conditional = clique->conditional(); + const Clique* clique, Key key) { + auto& conditional = clique->conditional(); if (std::find(conditional->beginParents(), conditional->endParents(), key) != conditional->endParents()) { @@ -305,12 +341,15 @@ public: * Check if the clique depends on any of the given keys. */ static bool hasDependency( - const sharedClique& clique, const std::set& keys) { - for (Key key : keys) { - if (hasDependency(clique, key)) { + const Clique* clique, const std::unordered_set& keys) { + auto& conditional = clique->conditional(); + for (auto it = conditional->beginParents(); + it != conditional->endParents(); ++it) { + if (keys.count(*it)) { return true; } } + return false; } }; diff --git a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp index 238ff6b3d..0cd9ecbac 100644 --- a/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp +++ b/gtsam_unstable/nonlinear/IncrementalFixedLagSmoother.cpp @@ -93,7 +93,7 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update( std::cout << std::endl; } - std::set additionalKeys = + std::unordered_set additionalKeys = BayesTreeMarginalizationHelper::gatherAdditionalKeysToReEliminate( isam_, marginalizableKeys); KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end());