Further optimize the implementation of BayesTreeMarginalizationHelper:
Now we won't re-emilinate any unnecessary nodes (we re-emilinated whole subtrees in the previous commits, which is not optimal)release/4.3a0
parent
14c3467520
commit
1a5e711f0e
|
@ -21,6 +21,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <deque>
|
||||||
#include <gtsam/inference/BayesTree.h>
|
#include <gtsam/inference/BayesTree.h>
|
||||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||||
#include <gtsam/base/debug.h>
|
#include <gtsam/base/debug.h>
|
||||||
|
@ -50,9 +51,12 @@ public:
|
||||||
* 2. Or it has a child node depending on a marginalizable variable AND the
|
* 2. Or it has a child node depending on a marginalizable variable AND the
|
||||||
* subtree rooted at that child contains non-marginalizables.
|
* subtree rooted at that child contains non-marginalizables.
|
||||||
*
|
*
|
||||||
* In addition, the subtrees under the aforementioned cliques that require
|
* In addition, for any descendant node depending on a marginalizable
|
||||||
* re-elimination, which contain non-marginalizable variables in their root
|
* variable, if the subtree rooted at that descendant contains
|
||||||
* node, also need to be re-eliminated.
|
* non-marginalizable variables (i.e., it lies on a path from one of the
|
||||||
|
* aforementioned cliques that require re-elimination to a node containing
|
||||||
|
* non-marginalizable variables at the leaf side), then it also needs to
|
||||||
|
* be re-eliminated.
|
||||||
*
|
*
|
||||||
* @param[in] bayesTree The Bayes tree
|
* @param[in] bayesTree The Bayes tree
|
||||||
* @param[in] marginalizableKeys Keys to be marginalized
|
* @param[in] marginalizableKeys Keys to be marginalized
|
||||||
|
@ -66,7 +70,7 @@ public:
|
||||||
std::set<Key> additionalKeys;
|
std::set<Key> additionalKeys;
|
||||||
std::set<Key> marginalizableKeySet(
|
std::set<Key> marginalizableKeySet(
|
||||||
marginalizableKeys.begin(), marginalizableKeys.end());
|
marginalizableKeys.begin(), marginalizableKeys.end());
|
||||||
std::set<sharedClique> dependentSubtrees;
|
std::set<sharedClique> dependentCliques;
|
||||||
CachedSearch cachedSearch;
|
CachedSearch cachedSearch;
|
||||||
|
|
||||||
// Check each clique that contains a marginalizable key
|
// Check each clique that contains a marginalizable key
|
||||||
|
@ -77,17 +81,14 @@ public:
|
||||||
// Add frontal variables from current clique
|
// Add frontal variables from current clique
|
||||||
addCliqueToKeySet(clique, &additionalKeys);
|
addCliqueToKeySet(clique, &additionalKeys);
|
||||||
|
|
||||||
// Then gather dependent subtrees to be added later
|
// Then add the dependent cliques
|
||||||
gatherDependentSubtrees(
|
for (const sharedClique& dependent :
|
||||||
clique, marginalizableKeySet, &dependentSubtrees, &cachedSearch);
|
gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) {
|
||||||
|
addCliqueToKeySet(dependent, &additionalKeys);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the remaining dependent cliques
|
|
||||||
for (const sharedClique& subtree : dependentSubtrees) {
|
|
||||||
addSubtreeToKeySet(subtree, &additionalKeys);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (debug) {
|
if (debug) {
|
||||||
std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: ";
|
std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: ";
|
||||||
for (const Key& key : additionalKeys) {
|
for (const Key& key : additionalKeys) {
|
||||||
|
@ -219,53 +220,53 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gather all subtrees that depend on a marginalizable key and contain
|
* Gather all dependent nodes that lie on a path from the root clique
|
||||||
* non-marginalizable variables in their root.
|
* to a clique containing a non-marginalizable variable at the leaf side.
|
||||||
*
|
*
|
||||||
* @param[in] rootClique The starting clique
|
* @param[in] rootClique The root clique
|
||||||
* @param[in] marginalizableKeys Set of keys to be marginalized
|
* @param[in] marginalizableKeys Set of keys to be marginalized
|
||||||
* @param[out] dependentSubtrees Pointer to set storing dependent cliques
|
|
||||||
*/
|
*/
|
||||||
static void gatherDependentSubtrees(
|
static std::set<sharedClique> gatherDependentCliques(
|
||||||
const sharedClique& rootClique,
|
const sharedClique& rootClique,
|
||||||
const std::set<Key>& marginalizableKeys,
|
const std::set<Key>& marginalizableKeys,
|
||||||
std::set<sharedClique>* dependentSubtrees,
|
|
||||||
CachedSearch* cache) {
|
CachedSearch* cache) {
|
||||||
for (Key key : rootClique->conditional()->frontals()) {
|
std::vector<sharedClique> dependentChildren;
|
||||||
if (marginalizableKeys.count(key)) {
|
dependentChildren.reserve(rootClique->children.size());
|
||||||
// Find children that depend on this key
|
for (const sharedClique& child : rootClique->children) {
|
||||||
for (const sharedClique& child : rootClique->children) {
|
if (hasDependency(child, marginalizableKeys)) {
|
||||||
if (!dependentSubtrees->count(child) &&
|
dependentChildren.push_back(child);
|
||||||
hasDependency(child, key)) {
|
|
||||||
getSubtreesContainingNonMarginalizables(
|
|
||||||
child, marginalizableKeys, cache, dependentSubtrees);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gather all subtrees that contain non-marginalizable variables in its root.
|
* A helper function for the above gatherDependentCliques().
|
||||||
*/
|
*/
|
||||||
static void getSubtreesContainingNonMarginalizables(
|
static std::set<sharedClique> gatherDependentCliquesFromChildren(
|
||||||
const sharedClique& rootClique,
|
const std::vector<sharedClique>& dependentChildren,
|
||||||
const std::set<Key>& marginalizableKeys,
|
const std::set<Key>& marginalizableKeys,
|
||||||
CachedSearch* cache,
|
CachedSearch* cache) {
|
||||||
std::set<sharedClique>* subtreesContainingNonMarginalizables) {
|
std::deque<sharedClique> descendants(
|
||||||
// If the root clique itself contains non-marginalizable variables, we
|
dependentChildren.begin(), dependentChildren.end());
|
||||||
// just add it to subtreesContainingNonMarginalizables;
|
std::set<sharedClique> dependentCliques;
|
||||||
if (!isWholeCliqueMarginalizable(rootClique, marginalizableKeys, cache)) {
|
while (!descendants.empty()) {
|
||||||
subtreesContainingNonMarginalizables->insert(rootClique);
|
sharedClique descendant = descendants.front();
|
||||||
return;
|
descendants.pop_front();
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, we need to recursively check the children
|
// If the subtree rooted at this descendant contains non-marginalizables,
|
||||||
for (const sharedClique& child : rootClique->children) {
|
// it must lie on a path from the root clique to a clique containing
|
||||||
getSubtreesContainingNonMarginalizables(
|
// non-marginalizables at the leaf side.
|
||||||
child, marginalizableKeys, cache,
|
if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) {
|
||||||
subtreesContainingNonMarginalizables);
|
dependentCliques.insert(descendant);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add all children of the current descendant to the set descendants.
|
||||||
|
for (const sharedClique& child : descendant->children) {
|
||||||
|
descendants.push_back(child);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return dependentCliques;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -282,28 +283,6 @@ public:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Add all frontal variables from a subtree to a key set.
|
|
||||||
*
|
|
||||||
* @param[in] subRoot Root clique of the subtree
|
|
||||||
* @param[out] additionalKeys Pointer to the output key set
|
|
||||||
*/
|
|
||||||
static void addSubtreeToKeySet(
|
|
||||||
const sharedClique& subRoot,
|
|
||||||
std::set<Key>* additionalKeys) {
|
|
||||||
std::set<sharedClique> cliques;
|
|
||||||
cliques.insert(subRoot);
|
|
||||||
while(!cliques.empty()) {
|
|
||||||
auto begin = cliques.begin();
|
|
||||||
sharedClique clique = *begin;
|
|
||||||
cliques.erase(begin);
|
|
||||||
addCliqueToKeySet(clique, additionalKeys);
|
|
||||||
for (const sharedClique& child : clique->children) {
|
|
||||||
cliques.insert(child);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if the clique depends on the given key.
|
* Check if the clique depends on the given key.
|
||||||
*
|
*
|
||||||
|
@ -322,6 +301,19 @@ public:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if the clique depends on any of the given keys.
|
||||||
|
*/
|
||||||
|
static bool hasDependency(
|
||||||
|
const sharedClique& clique, const std::set<Key>& keys) {
|
||||||
|
for (Key key : keys) {
|
||||||
|
if (hasDependency(clique, key)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
// BayesTreeMarginalizationHelper
|
// BayesTreeMarginalizationHelper
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue