Further optimize the implementation of BayesTreeMarginalizationHelper:

Now we won't re-emilinate any unnecessary nodes (we re-emilinated
whole subtrees in the previous commits, which is not optimal)
release/4.3a0
Jeffrey 2024-10-31 21:52:45 +08:00
parent 14c3467520
commit 1a5e711f0e
1 changed files with 58 additions and 66 deletions

View File

@ -21,6 +21,7 @@
#pragma once #pragma once
#include <unordered_map> #include <unordered_map>
#include <deque>
#include <gtsam/inference/BayesTree.h> #include <gtsam/inference/BayesTree.h>
#include <gtsam/inference/BayesTreeCliqueBase.h> #include <gtsam/inference/BayesTreeCliqueBase.h>
#include <gtsam/base/debug.h> #include <gtsam/base/debug.h>
@ -50,9 +51,12 @@ public:
* 2. Or it has a child node depending on a marginalizable variable AND the * 2. Or it has a child node depending on a marginalizable variable AND the
* subtree rooted at that child contains non-marginalizables. * subtree rooted at that child contains non-marginalizables.
* *
* In addition, the subtrees under the aforementioned cliques that require * In addition, for any descendant node depending on a marginalizable
* re-elimination, which contain non-marginalizable variables in their root * variable, if the subtree rooted at that descendant contains
* node, also need to be re-eliminated. * non-marginalizable variables (i.e., it lies on a path from one of the
* aforementioned cliques that require re-elimination to a node containing
* non-marginalizable variables at the leaf side), then it also needs to
* be re-eliminated.
* *
* @param[in] bayesTree The Bayes tree * @param[in] bayesTree The Bayes tree
* @param[in] marginalizableKeys Keys to be marginalized * @param[in] marginalizableKeys Keys to be marginalized
@ -66,7 +70,7 @@ public:
std::set<Key> additionalKeys; std::set<Key> additionalKeys;
std::set<Key> marginalizableKeySet( std::set<Key> marginalizableKeySet(
marginalizableKeys.begin(), marginalizableKeys.end()); marginalizableKeys.begin(), marginalizableKeys.end());
std::set<sharedClique> dependentSubtrees; std::set<sharedClique> dependentCliques;
CachedSearch cachedSearch; CachedSearch cachedSearch;
// Check each clique that contains a marginalizable key // Check each clique that contains a marginalizable key
@ -77,17 +81,14 @@ public:
// Add frontal variables from current clique // Add frontal variables from current clique
addCliqueToKeySet(clique, &additionalKeys); addCliqueToKeySet(clique, &additionalKeys);
// Then gather dependent subtrees to be added later // Then add the dependent cliques
gatherDependentSubtrees( for (const sharedClique& dependent :
clique, marginalizableKeySet, &dependentSubtrees, &cachedSearch); gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) {
addCliqueToKeySet(dependent, &additionalKeys);
}
} }
} }
// Add the remaining dependent cliques
for (const sharedClique& subtree : dependentSubtrees) {
addSubtreeToKeySet(subtree, &additionalKeys);
}
if (debug) { if (debug) {
std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: ";
for (const Key& key : additionalKeys) { for (const Key& key : additionalKeys) {
@ -219,53 +220,53 @@ public:
} }
/** /**
* Gather all subtrees that depend on a marginalizable key and contain * Gather all dependent nodes that lie on a path from the root clique
* non-marginalizable variables in their root. * to a clique containing a non-marginalizable variable at the leaf side.
* *
* @param[in] rootClique The starting clique * @param[in] rootClique The root clique
* @param[in] marginalizableKeys Set of keys to be marginalized * @param[in] marginalizableKeys Set of keys to be marginalized
* @param[out] dependentSubtrees Pointer to set storing dependent cliques
*/ */
static void gatherDependentSubtrees( static std::set<sharedClique> gatherDependentCliques(
const sharedClique& rootClique, const sharedClique& rootClique,
const std::set<Key>& marginalizableKeys, const std::set<Key>& marginalizableKeys,
std::set<sharedClique>* dependentSubtrees,
CachedSearch* cache) { CachedSearch* cache) {
for (Key key : rootClique->conditional()->frontals()) { std::vector<sharedClique> dependentChildren;
if (marginalizableKeys.count(key)) { dependentChildren.reserve(rootClique->children.size());
// Find children that depend on this key for (const sharedClique& child : rootClique->children) {
for (const sharedClique& child : rootClique->children) { if (hasDependency(child, marginalizableKeys)) {
if (!dependentSubtrees->count(child) && dependentChildren.push_back(child);
hasDependency(child, key)) {
getSubtreesContainingNonMarginalizables(
child, marginalizableKeys, cache, dependentSubtrees);
}
}
} }
} }
return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache);
} }
/** /**
* Gather all subtrees that contain non-marginalizable variables in its root. * A helper function for the above gatherDependentCliques().
*/ */
static void getSubtreesContainingNonMarginalizables( static std::set<sharedClique> gatherDependentCliquesFromChildren(
const sharedClique& rootClique, const std::vector<sharedClique>& dependentChildren,
const std::set<Key>& marginalizableKeys, const std::set<Key>& marginalizableKeys,
CachedSearch* cache, CachedSearch* cache) {
std::set<sharedClique>* subtreesContainingNonMarginalizables) { std::deque<sharedClique> descendants(
// If the root clique itself contains non-marginalizable variables, we dependentChildren.begin(), dependentChildren.end());
// just add it to subtreesContainingNonMarginalizables; std::set<sharedClique> dependentCliques;
if (!isWholeCliqueMarginalizable(rootClique, marginalizableKeys, cache)) { while (!descendants.empty()) {
subtreesContainingNonMarginalizables->insert(rootClique); sharedClique descendant = descendants.front();
return; descendants.pop_front();
}
// Otherwise, we need to recursively check the children // If the subtree rooted at this descendant contains non-marginalizables,
for (const sharedClique& child : rootClique->children) { // it must lie on a path from the root clique to a clique containing
getSubtreesContainingNonMarginalizables( // non-marginalizables at the leaf side.
child, marginalizableKeys, cache, if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) {
subtreesContainingNonMarginalizables); dependentCliques.insert(descendant);
}
// Add all children of the current descendant to the set descendants.
for (const sharedClique& child : descendant->children) {
descendants.push_back(child);
}
} }
return dependentCliques;
} }
/** /**
@ -282,28 +283,6 @@ public:
} }
} }
/**
* Add all frontal variables from a subtree to a key set.
*
* @param[in] subRoot Root clique of the subtree
* @param[out] additionalKeys Pointer to the output key set
*/
static void addSubtreeToKeySet(
const sharedClique& subRoot,
std::set<Key>* additionalKeys) {
std::set<sharedClique> cliques;
cliques.insert(subRoot);
while(!cliques.empty()) {
auto begin = cliques.begin();
sharedClique clique = *begin;
cliques.erase(begin);
addCliqueToKeySet(clique, additionalKeys);
for (const sharedClique& child : clique->children) {
cliques.insert(child);
}
}
}
/** /**
* Check if the clique depends on the given key. * Check if the clique depends on the given key.
* *
@ -322,6 +301,19 @@ public:
return false; return false;
} }
} }
/**
* Check if the clique depends on any of the given keys.
*/
static bool hasDependency(
const sharedClique& clique, const std::set<Key>& keys) {
for (Key key : keys) {
if (hasDependency(clique, key)) {
return true;
}
}
return false;
}
}; };
// BayesTreeMarginalizationHelper // BayesTreeMarginalizationHelper