Some refinement in BayesTreeMarginalizationHelper:

1. Skip subtrees that have already been visited when searching for
   dependent cliques;
2. Avoid copying shared_ptrs (which needs extra expensive atomic
   operations) in the searching. Use const Clique* instead of
   sharedClique whenever possible;
3. Use std::unordered_set instead of std::set to improve average
   searching speed.
release/4.3a0
Jeffrey 2024-11-02 17:14:01 +08:00
parent 0d9c3a9958
commit 06dac43cae
2 changed files with 104 additions and 65 deletions

View File

@ -21,6 +21,7 @@
#pragma once #pragma once
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <deque> #include <deque>
#include <gtsam/inference/BayesTree.h> #include <gtsam/inference/BayesTree.h>
#include <gtsam/inference/BayesTreeCliqueBase.h> #include <gtsam/inference/BayesTreeCliqueBase.h>
@ -62,30 +63,18 @@ public:
* @param[in] marginalizableKeys Keys to be marginalized * @param[in] marginalizableKeys Keys to be marginalized
* @return Set of additional keys that need to be re-eliminated * @return Set of additional keys that need to be re-eliminated
*/ */
static std::set<Key> gatherAdditionalKeysToReEliminate( static std::unordered_set<Key>
gatherAdditionalKeysToReEliminate(
const BayesTree& bayesTree, const BayesTree& bayesTree,
const KeyVector& marginalizableKeys) { const KeyVector& marginalizableKeys) {
const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); const bool debug = ISDEBUG("BayesTreeMarginalizationHelper");
std::set<Key> additionalKeys; std::unordered_set<const Clique*> additionalCliques =
std::set<Key> marginalizableKeySet( gatherAdditionalCliquesToReEliminate(bayesTree, marginalizableKeys);
marginalizableKeys.begin(), marginalizableKeys.end());
CachedSearch cachedSearch;
// Check each clique that contains a marginalizable key std::unordered_set<Key> additionalKeys;
for (const sharedClique& clique : for (const Clique* clique : additionalCliques) {
getCliquesContainingKeys(bayesTree, marginalizableKeySet)) {
if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) {
// Add frontal variables from current clique
addCliqueToKeySet(clique, &additionalKeys); addCliqueToKeySet(clique, &additionalKeys);
// Then add the dependent cliques
for (const sharedClique& dependent :
gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) {
addCliqueToKeySet(dependent, &additionalKeys);
}
}
} }
if (debug) { if (debug) {
@ -100,6 +89,41 @@ public:
} }
protected: protected:
/**
* This function identifies cliques that need to be re-eliminated before
* performing marginalization.
* See the docstring of @ref gatherAdditionalKeysToReEliminate().
*/
static std::unordered_set<const Clique*>
gatherAdditionalCliquesToReEliminate(
const BayesTree& bayesTree,
const KeyVector& marginalizableKeys) {
std::unordered_set<const Clique*> additionalCliques;
std::unordered_set<Key> marginalizableKeySet(
marginalizableKeys.begin(), marginalizableKeys.end());
CachedSearch cachedSearch;
// Check each clique that contains a marginalizable key
for (const Clique* clique :
getCliquesContainingKeys(bayesTree, marginalizableKeySet)) {
if (additionalCliques.count(clique)) {
// The clique has already been visited. This can happen when an
// ancestor of the current clique also contain some marginalizable
// varaibles and it's processed beore the current.
continue;
}
if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) {
// Add the current clique
additionalCliques.insert(clique);
// Then add the dependent cliques
gatherDependentCliques(clique, marginalizableKeySet, &additionalCliques,
&cachedSearch);
}
}
return additionalCliques;
}
/** /**
* Gather the cliques containing any of the given keys. * Gather the cliques containing any of the given keys.
@ -108,12 +132,12 @@ public:
* @param[in] keysOfInterest Set of keys of interest * @param[in] keysOfInterest Set of keys of interest
* @return Set of cliques that contain any of the given keys * @return Set of cliques that contain any of the given keys
*/ */
static std::set<sharedClique> getCliquesContainingKeys( static std::unordered_set<const Clique*> getCliquesContainingKeys(
const BayesTree& bayesTree, const BayesTree& bayesTree,
const std::set<Key>& keysOfInterest) { const std::unordered_set<Key>& keysOfInterest) {
std::set<sharedClique> cliques; std::unordered_set<const Clique*> cliques;
for (const Key& key : keysOfInterest) { for (const Key& key : keysOfInterest) {
cliques.insert(bayesTree[key]); cliques.insert(bayesTree[key].get());
} }
return cliques; return cliques;
} }
@ -122,8 +146,8 @@ public:
* A struct to cache the results of the below two functions. * A struct to cache the results of the below two functions.
*/ */
struct CachedSearch { struct CachedSearch {
std::unordered_map<Clique*, bool> wholeMarginalizableCliques; std::unordered_map<const Clique*, bool> wholeMarginalizableCliques;
std::unordered_map<Clique*, bool> wholeMarginalizableSubtrees; std::unordered_map<const Clique*, bool> wholeMarginalizableSubtrees;
}; };
/** /**
@ -132,10 +156,10 @@ public:
* Note we use a cache map to avoid repeated searches. * Note we use a cache map to avoid repeated searches.
*/ */
static bool isWholeCliqueMarginalizable( static bool isWholeCliqueMarginalizable(
const sharedClique& clique, const Clique* clique,
const std::set<Key>& marginalizableKeys, const std::unordered_set<Key>& marginalizableKeys,
CachedSearch* cache) { CachedSearch* cache) {
auto it = cache->wholeMarginalizableCliques.find(clique.get()); auto it = cache->wholeMarginalizableCliques.find(clique);
if (it != cache->wholeMarginalizableCliques.end()) { if (it != cache->wholeMarginalizableCliques.end()) {
return it->second; return it->second;
} else { } else {
@ -146,7 +170,7 @@ public:
break; break;
} }
} }
cache->wholeMarginalizableCliques.insert({clique.get(), ret}); cache->wholeMarginalizableCliques.insert({clique, ret});
return ret; return ret;
} }
} }
@ -157,17 +181,17 @@ public:
* Note we use a cache map to avoid repeated searches. * Note we use a cache map to avoid repeated searches.
*/ */
static bool isWholeSubtreeMarginalizable( static bool isWholeSubtreeMarginalizable(
const sharedClique& subtree, const Clique* subtree,
const std::set<Key>& marginalizableKeys, const std::unordered_set<Key>& marginalizableKeys,
CachedSearch* cache) { CachedSearch* cache) {
auto it = cache->wholeMarginalizableSubtrees.find(subtree.get()); auto it = cache->wholeMarginalizableSubtrees.find(subtree);
if (it != cache->wholeMarginalizableSubtrees.end()) { if (it != cache->wholeMarginalizableSubtrees.end()) {
return it->second; return it->second;
} else { } else {
bool ret = true; bool ret = true;
if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) { if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) {
for (const sharedClique& child : subtree->children) { for (const sharedClique& child : subtree->children) {
if (!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { if (!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) {
ret = false; ret = false;
break; break;
} }
@ -175,7 +199,7 @@ public:
} else { } else {
ret = false; ret = false;
} }
cache->wholeMarginalizableSubtrees.insert({subtree.get(), ret}); cache->wholeMarginalizableSubtrees.insert({subtree, ret});
return ret; return ret;
} }
} }
@ -189,8 +213,8 @@ public:
* @return true if any variables in the clique need re-elimination * @return true if any variables in the clique need re-elimination
*/ */
static bool needsReelimination( static bool needsReelimination(
const sharedClique& clique, const Clique* clique,
const std::set<Key>& marginalizableKeys, const std::unordered_set<Key>& marginalizableKeys,
CachedSearch* cache) { CachedSearch* cache) {
bool hasNonMarginalizableAhead = false; bool hasNonMarginalizableAhead = false;
@ -206,8 +230,8 @@ public:
// Check if any child depends on this marginalizable key and the // Check if any child depends on this marginalizable key and the
// subtree rooted at that child contains non-marginalizables. // subtree rooted at that child contains non-marginalizables.
for (const sharedClique& child : clique->children) { for (const sharedClique& child : clique->children) {
if (hasDependency(child, key) && if (hasDependency(child.get(), key) &&
!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { !isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) {
return true; return true;
} }
} }
@ -225,47 +249,59 @@ public:
* @param[in] rootClique The root clique * @param[in] rootClique The root clique
* @param[in] marginalizableKeys Set of keys to be marginalized * @param[in] marginalizableKeys Set of keys to be marginalized
*/ */
static std::set<sharedClique> gatherDependentCliques( static void gatherDependentCliques(
const sharedClique& rootClique, const Clique* rootClique,
const std::set<Key>& marginalizableKeys, const std::unordered_set<Key>& marginalizableKeys,
std::unordered_set<const Clique*>* additionalCliques,
CachedSearch* cache) { CachedSearch* cache) {
std::vector<sharedClique> dependentChildren; std::vector<const Clique*> dependentChildren;
dependentChildren.reserve(rootClique->children.size()); dependentChildren.reserve(rootClique->children.size());
for (const sharedClique& child : rootClique->children) { for (const sharedClique& child : rootClique->children) {
if (hasDependency(child, marginalizableKeys)) { if (additionalCliques->count(child.get())) {
dependentChildren.push_back(child); // This child has already been visited. This can happen if the
// child itself contains a marginalizable variable and it's
// processed before the current rootClique.
continue;
}
if (hasDependency(child.get(), marginalizableKeys)) {
dependentChildren.push_back(child.get());
} }
} }
return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache); gatherDependentCliquesFromChildren(
dependentChildren, marginalizableKeys, additionalCliques, cache);
} }
/** /**
* A helper function for the above gatherDependentCliques(). * A helper function for the above gatherDependentCliques().
*/ */
static std::set<sharedClique> gatherDependentCliquesFromChildren( static void gatherDependentCliquesFromChildren(
const std::vector<sharedClique>& dependentChildren, const std::vector<const Clique*>& dependentChildren,
const std::set<Key>& marginalizableKeys, const std::unordered_set<Key>& marginalizableKeys,
std::unordered_set<const Clique*>* additionalCliques,
CachedSearch* cache) { CachedSearch* cache) {
std::deque<sharedClique> descendants( std::deque<const Clique*> descendants(
dependentChildren.begin(), dependentChildren.end()); dependentChildren.begin(), dependentChildren.end());
std::set<sharedClique> dependentCliques;
while (!descendants.empty()) { while (!descendants.empty()) {
sharedClique descendant = descendants.front(); const Clique* descendant = descendants.front();
descendants.pop_front(); descendants.pop_front();
// If the subtree rooted at this descendant contains non-marginalizables, // If the subtree rooted at this descendant contains non-marginalizables,
// it must lie on a path from the root clique to a clique containing // it must lie on a path from the root clique to a clique containing
// non-marginalizables at the leaf side. // non-marginalizables at the leaf side.
if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) {
dependentCliques.insert(descendant); additionalCliques->insert(descendant);
}
// Add all children of the current descendant to the set descendants. // Add children of the current descendant to the set descendants.
for (const sharedClique& child : descendant->children) { for (const sharedClique& child : descendant->children) {
descendants.push_back(child); if (additionalCliques->count(child.get())) {
// This child has already been visited.
continue;
} else {
descendants.push_back(child.get());
}
}
} }
} }
return dependentCliques;
} }
/** /**
@ -275,8 +311,8 @@ public:
* @param[out] additionalKeys Pointer to the output key set * @param[out] additionalKeys Pointer to the output key set
*/ */
static void addCliqueToKeySet( static void addCliqueToKeySet(
const sharedClique& clique, const Clique* clique,
std::set<Key>* additionalKeys) { std::unordered_set<Key>* additionalKeys) {
for (Key key : clique->conditional()->frontals()) { for (Key key : clique->conditional()->frontals()) {
additionalKeys->insert(key); additionalKeys->insert(key);
} }
@ -290,8 +326,8 @@ public:
* @return true if clique depends on the key * @return true if clique depends on the key
*/ */
static bool hasDependency( static bool hasDependency(
const sharedClique& clique, Key key) { const Clique* clique, Key key) {
auto conditional = clique->conditional(); auto& conditional = clique->conditional();
if (std::find(conditional->beginParents(), if (std::find(conditional->beginParents(),
conditional->endParents(), key) conditional->endParents(), key)
!= conditional->endParents()) { != conditional->endParents()) {
@ -305,12 +341,15 @@ public:
* Check if the clique depends on any of the given keys. * Check if the clique depends on any of the given keys.
*/ */
static bool hasDependency( static bool hasDependency(
const sharedClique& clique, const std::set<Key>& keys) { const Clique* clique, const std::unordered_set<Key>& keys) {
for (Key key : keys) { auto& conditional = clique->conditional();
if (hasDependency(clique, key)) { for (auto it = conditional->beginParents();
it != conditional->endParents(); ++it) {
if (keys.count(*it)) {
return true; return true;
} }
} }
return false; return false;
} }
}; };

View File

@ -93,7 +93,7 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update(
std::cout << std::endl; std::cout << std::endl;
} }
std::set<Key> additionalKeys = std::unordered_set<Key> additionalKeys =
BayesTreeMarginalizationHelper<ISAM2>::gatherAdditionalKeysToReEliminate( BayesTreeMarginalizationHelper<ISAM2>::gatherAdditionalKeysToReEliminate(
isam_, marginalizableKeys); isam_, marginalizableKeys);
KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end()); KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end());