Further optimize the implementation of BayesTreeMarginalizationHelper:
Now we won't re-emilinate any unnecessary nodes (we re-emilinated whole subtrees in the previous commits, which is not optimal)release/4.3a0
parent
14c3467520
commit
1a5e711f0e
|
@ -21,6 +21,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <deque>
|
||||
#include <gtsam/inference/BayesTree.h>
|
||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
|
@ -50,9 +51,12 @@ public:
|
|||
* 2. Or it has a child node depending on a marginalizable variable AND the
|
||||
* subtree rooted at that child contains non-marginalizables.
|
||||
*
|
||||
* In addition, the subtrees under the aforementioned cliques that require
|
||||
* re-elimination, which contain non-marginalizable variables in their root
|
||||
* node, also need to be re-eliminated.
|
||||
* In addition, for any descendant node depending on a marginalizable
|
||||
* variable, if the subtree rooted at that descendant contains
|
||||
* non-marginalizable variables (i.e., it lies on a path from one of the
|
||||
* aforementioned cliques that require re-elimination to a node containing
|
||||
* non-marginalizable variables at the leaf side), then it also needs to
|
||||
* be re-eliminated.
|
||||
*
|
||||
* @param[in] bayesTree The Bayes tree
|
||||
* @param[in] marginalizableKeys Keys to be marginalized
|
||||
|
@ -66,7 +70,7 @@ public:
|
|||
std::set<Key> additionalKeys;
|
||||
std::set<Key> marginalizableKeySet(
|
||||
marginalizableKeys.begin(), marginalizableKeys.end());
|
||||
std::set<sharedClique> dependentSubtrees;
|
||||
std::set<sharedClique> dependentCliques;
|
||||
CachedSearch cachedSearch;
|
||||
|
||||
// Check each clique that contains a marginalizable key
|
||||
|
@ -77,15 +81,12 @@ public:
|
|||
// Add frontal variables from current clique
|
||||
addCliqueToKeySet(clique, &additionalKeys);
|
||||
|
||||
// Then gather dependent subtrees to be added later
|
||||
gatherDependentSubtrees(
|
||||
clique, marginalizableKeySet, &dependentSubtrees, &cachedSearch);
|
||||
// Then add the dependent cliques
|
||||
for (const sharedClique& dependent :
|
||||
gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) {
|
||||
addCliqueToKeySet(dependent, &additionalKeys);
|
||||
}
|
||||
}
|
||||
|
||||
// Add the remaining dependent cliques
|
||||
for (const sharedClique& subtree : dependentSubtrees) {
|
||||
addSubtreeToKeySet(subtree, &additionalKeys);
|
||||
}
|
||||
|
||||
if (debug) {
|
||||
|
@ -219,54 +220,54 @@ public:
|
|||
}
|
||||
|
||||
/**
|
||||
* Gather all subtrees that depend on a marginalizable key and contain
|
||||
* non-marginalizable variables in their root.
|
||||
* Gather all dependent nodes that lie on a path from the root clique
|
||||
* to a clique containing a non-marginalizable variable at the leaf side.
|
||||
*
|
||||
* @param[in] rootClique The starting clique
|
||||
* @param[in] rootClique The root clique
|
||||
* @param[in] marginalizableKeys Set of keys to be marginalized
|
||||
* @param[out] dependentSubtrees Pointer to set storing dependent cliques
|
||||
*/
|
||||
static void gatherDependentSubtrees(
|
||||
static std::set<sharedClique> gatherDependentCliques(
|
||||
const sharedClique& rootClique,
|
||||
const std::set<Key>& marginalizableKeys,
|
||||
std::set<sharedClique>* dependentSubtrees,
|
||||
CachedSearch* cache) {
|
||||
for (Key key : rootClique->conditional()->frontals()) {
|
||||
if (marginalizableKeys.count(key)) {
|
||||
// Find children that depend on this key
|
||||
std::vector<sharedClique> dependentChildren;
|
||||
dependentChildren.reserve(rootClique->children.size());
|
||||
for (const sharedClique& child : rootClique->children) {
|
||||
if (!dependentSubtrees->count(child) &&
|
||||
hasDependency(child, key)) {
|
||||
getSubtreesContainingNonMarginalizables(
|
||||
child, marginalizableKeys, cache, dependentSubtrees);
|
||||
}
|
||||
}
|
||||
if (hasDependency(child, marginalizableKeys)) {
|
||||
dependentChildren.push_back(child);
|
||||
}
|
||||
}
|
||||
return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gather all subtrees that contain non-marginalizable variables in its root.
|
||||
* A helper function for the above gatherDependentCliques().
|
||||
*/
|
||||
static void getSubtreesContainingNonMarginalizables(
|
||||
const sharedClique& rootClique,
|
||||
static std::set<sharedClique> gatherDependentCliquesFromChildren(
|
||||
const std::vector<sharedClique>& dependentChildren,
|
||||
const std::set<Key>& marginalizableKeys,
|
||||
CachedSearch* cache,
|
||||
std::set<sharedClique>* subtreesContainingNonMarginalizables) {
|
||||
// If the root clique itself contains non-marginalizable variables, we
|
||||
// just add it to subtreesContainingNonMarginalizables;
|
||||
if (!isWholeCliqueMarginalizable(rootClique, marginalizableKeys, cache)) {
|
||||
subtreesContainingNonMarginalizables->insert(rootClique);
|
||||
return;
|
||||
CachedSearch* cache) {
|
||||
std::deque<sharedClique> descendants(
|
||||
dependentChildren.begin(), dependentChildren.end());
|
||||
std::set<sharedClique> dependentCliques;
|
||||
while (!descendants.empty()) {
|
||||
sharedClique descendant = descendants.front();
|
||||
descendants.pop_front();
|
||||
|
||||
// If the subtree rooted at this descendant contains non-marginalizables,
|
||||
// it must lie on a path from the root clique to a clique containing
|
||||
// non-marginalizables at the leaf side.
|
||||
if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) {
|
||||
dependentCliques.insert(descendant);
|
||||
}
|
||||
|
||||
// Otherwise, we need to recursively check the children
|
||||
for (const sharedClique& child : rootClique->children) {
|
||||
getSubtreesContainingNonMarginalizables(
|
||||
child, marginalizableKeys, cache,
|
||||
subtreesContainingNonMarginalizables);
|
||||
// Add all children of the current descendant to the set descendants.
|
||||
for (const sharedClique& child : descendant->children) {
|
||||
descendants.push_back(child);
|
||||
}
|
||||
}
|
||||
return dependentCliques;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add all frontal variables from a clique to a key set.
|
||||
|
@ -282,28 +283,6 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Add all frontal variables from a subtree to a key set.
|
||||
*
|
||||
* @param[in] subRoot Root clique of the subtree
|
||||
* @param[out] additionalKeys Pointer to the output key set
|
||||
*/
|
||||
static void addSubtreeToKeySet(
|
||||
const sharedClique& subRoot,
|
||||
std::set<Key>* additionalKeys) {
|
||||
std::set<sharedClique> cliques;
|
||||
cliques.insert(subRoot);
|
||||
while(!cliques.empty()) {
|
||||
auto begin = cliques.begin();
|
||||
sharedClique clique = *begin;
|
||||
cliques.erase(begin);
|
||||
addCliqueToKeySet(clique, additionalKeys);
|
||||
for (const sharedClique& child : clique->children) {
|
||||
cliques.insert(child);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the clique depends on the given key.
|
||||
*
|
||||
|
@ -322,6 +301,19 @@ public:
|
|||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the clique depends on any of the given keys.
|
||||
*/
|
||||
static bool hasDependency(
|
||||
const sharedClique& clique, const std::set<Key>& keys) {
|
||||
for (Key key : keys) {
|
||||
if (hasDependency(clique, key)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
// BayesTreeMarginalizationHelper
|
||||
|
||||
|
|
Loading…
Reference in New Issue