Update BayesTreeMarginalizationHelper:
1. Refactor code in BayesTreeMarginalizationHelper; 2. And avoid the unnecessary re-elimination of subtrees that only contain marginalizable variables;release/4.3a0
parent
c6ba2b5fd8
commit
67b0b78ea1
|
@ -20,6 +20,7 @@
|
||||||
// \callgraph
|
// \callgraph
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
#include <gtsam/inference/BayesTree.h>
|
#include <gtsam/inference/BayesTree.h>
|
||||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||||
#include <gtsam/base/debug.h>
|
#include <gtsam/base/debug.h>
|
||||||
|
@ -37,109 +38,54 @@ public:
|
||||||
using Clique = typename BayesTree::Clique;
|
using Clique = typename BayesTree::Clique;
|
||||||
using sharedClique = typename BayesTree::sharedClique;
|
using sharedClique = typename BayesTree::sharedClique;
|
||||||
|
|
||||||
/** Get the additional keys that need reelimination when marginalizing
|
/**
|
||||||
* the variables in @p marginalizableKeys from the Bayes tree @p bayesTree.
|
* This function identifies variables that need to be re-eliminated before
|
||||||
|
* performing marginalization.
|
||||||
*
|
*
|
||||||
* @param[in] bayesTree The Bayes tree.
|
* Re-elimination is necessary for a clique containing marginalizable
|
||||||
* @param[in] marginalizableKeys The keys to be marginalized.
|
* variables if:
|
||||||
*
|
*
|
||||||
|
* 1. Some non-marginalizable variables appear before marginalizable ones
|
||||||
|
* in that clique;
|
||||||
|
* 2. Or it has a child node depending on a marginalizable variable AND the
|
||||||
|
* subtree rooted at that child contains non-marginalizables.
|
||||||
*
|
*
|
||||||
* When marginalizing a variable @f$ \theta @f$ from a Bayes tree, some
|
* In addition, the subtrees under the aforementioned cliques that require
|
||||||
* nodes may need reelimination to ensure the variables to marginalize
|
* re-elimination, which contain non-marginalizable variables in their root
|
||||||
* be eliminated first.
|
* node, also need to be re-eliminated.
|
||||||
*
|
|
||||||
* We should consider two cases:
|
|
||||||
*
|
|
||||||
* 1. If a child node relies on @f$ \theta @f$ (i.e., @f$ \theta @f$
|
|
||||||
* is a parent / separator of the node), then the frontal
|
|
||||||
* variables of the child node need to be reeliminated. In
|
|
||||||
* addition, all the descendants of the child node also need to
|
|
||||||
* be reeliminated.
|
|
||||||
*
|
|
||||||
* 2. If other frontal variables in the same node with @f$ \theta @f$
|
|
||||||
* are in front of @f$ \theta @f$ but not to be marginalized, then
|
|
||||||
* these variables also need to be reeliminated.
|
|
||||||
*
|
|
||||||
* These variables were eliminated before @f$ \theta @f$ in the original
|
|
||||||
* Bayes tree, and after reelimination they will be eliminated after
|
|
||||||
* @f$ \theta @f$ so that @f$ \theta @f$ can be marginalized safely.
|
|
||||||
*
|
*
|
||||||
|
* @param[in] bayesTree The Bayes tree
|
||||||
|
* @param[in] marginalizableKeys Keys to be marginalized
|
||||||
|
* @return Set of additional keys that need to be re-eliminated
|
||||||
*/
|
*/
|
||||||
static void gatherAdditionalKeysToReEliminate(
|
static std::set<Key> gatherAdditionalKeysToReEliminate(
|
||||||
const BayesTree& bayesTree,
|
const BayesTree& bayesTree,
|
||||||
const KeyVector& marginalizableKeys,
|
const KeyVector& marginalizableKeys) {
|
||||||
std::set<Key>& additionalKeys) {
|
|
||||||
const bool debug = ISDEBUG("BayesTreeMarginalizationHelper");
|
const bool debug = ISDEBUG("BayesTreeMarginalizationHelper");
|
||||||
|
|
||||||
std::set<Key> marginalizableKeySet(marginalizableKeys.begin(), marginalizableKeys.end());
|
std::set<Key> additionalKeys;
|
||||||
std::set<sharedClique> checkedCliques;
|
std::set<Key> marginalizableKeySet(
|
||||||
|
marginalizableKeys.begin(), marginalizableKeys.end());
|
||||||
|
std::set<sharedClique> dependentSubtrees;
|
||||||
|
CachedSearch cachedSearch;
|
||||||
|
|
||||||
std::set<sharedClique> dependentCliques;
|
// Check each clique that contains a marginalizable key
|
||||||
for (const Key& key : marginalizableKeySet) {
|
for (const sharedClique& clique :
|
||||||
sharedClique clique = bayesTree[key];
|
getCliquesContainingKeys(bayesTree, marginalizableKeySet)) {
|
||||||
if (checkedCliques.count(clique)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
checkedCliques.insert(clique);
|
|
||||||
|
|
||||||
bool need_reeliminate = false;
|
if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) {
|
||||||
bool has_non_marginalizable_ahead = false;
|
// Add frontal variables from current clique
|
||||||
for (Key i: clique->conditional()->frontals()) {
|
addCliqueToKeySet(clique, &additionalKeys);
|
||||||
if (marginalizableKeySet.count(i)) {
|
|
||||||
if (has_non_marginalizable_ahead) {
|
|
||||||
// Case 2 in the docstring
|
|
||||||
need_reeliminate = true;
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
// Check whether there's a child node dependent on this key.
|
|
||||||
for(const sharedClique& child: clique->children) {
|
|
||||||
if (std::find(child->conditional()->beginParents(),
|
|
||||||
child->conditional()->endParents(), i)
|
|
||||||
!= child->conditional()->endParents()) {
|
|
||||||
// Case 1 in the docstring
|
|
||||||
need_reeliminate = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
has_non_marginalizable_ahead = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!need_reeliminate) {
|
// Then gather dependent subtrees to be added later
|
||||||
// No variable needs to be reeliminated
|
gatherDependentSubtrees(
|
||||||
continue;
|
clique, marginalizableKeySet, &dependentSubtrees, &cachedSearch);
|
||||||
} else {
|
|
||||||
// Need to reeliminate the current clique and all its children
|
|
||||||
// that rely on a marginalizable key.
|
|
||||||
for (Key i: clique->conditional()->frontals()) {
|
|
||||||
additionalKeys.insert(i);
|
|
||||||
for (const sharedClique& child: clique->children) {
|
|
||||||
if (!dependentCliques.count(child) &&
|
|
||||||
std::find(child->conditional()->beginParents(),
|
|
||||||
child->conditional()->endParents(), i)
|
|
||||||
!= child->conditional()->endParents()) {
|
|
||||||
dependentCliques.insert(child);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recursively add the dependent keys
|
// Add the remaining dependent cliques
|
||||||
while (!dependentCliques.empty()) {
|
for (const sharedClique& subtree : dependentSubtrees) {
|
||||||
auto begin = dependentCliques.begin();
|
addSubtreeToKeySet(subtree, &additionalKeys);
|
||||||
sharedClique clique = *begin;
|
|
||||||
dependentCliques.erase(begin);
|
|
||||||
|
|
||||||
for (Key key : clique->conditional()->frontals()) {
|
|
||||||
additionalKeys.insert(key);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const sharedClique& child: clique->children) {
|
|
||||||
dependentCliques.insert(child);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (debug) {
|
if (debug) {
|
||||||
|
@ -149,6 +95,232 @@ public:
|
||||||
}
|
}
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return additionalKeys;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gather the cliques containing any of the given keys.
|
||||||
|
*
|
||||||
|
* @param[in] bayesTree The Bayes tree
|
||||||
|
* @param[in] keysOfInterest Set of keys of interest
|
||||||
|
* @return Set of cliques that contain any of the given keys
|
||||||
|
*/
|
||||||
|
static std::set<sharedClique> getCliquesContainingKeys(
|
||||||
|
const BayesTree& bayesTree,
|
||||||
|
const std::set<Key>& keysOfInterest) {
|
||||||
|
std::set<sharedClique> cliques;
|
||||||
|
for (const Key& key : keysOfInterest) {
|
||||||
|
cliques.insert(bayesTree[key]);
|
||||||
|
}
|
||||||
|
return cliques;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A struct to cache the results of the below two functions.
|
||||||
|
*/
|
||||||
|
struct CachedSearch {
|
||||||
|
std::unordered_map<Clique*, bool> wholeMarginalizableCliques;
|
||||||
|
std::unordered_map<Clique*, bool> wholeMarginalizableSubtrees;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if all variables in the clique are marginalizable.
|
||||||
|
*
|
||||||
|
* Note we use a cache map to avoid repeated searches.
|
||||||
|
*/
|
||||||
|
static bool isWholeCliqueMarginalizable(
|
||||||
|
const sharedClique& clique,
|
||||||
|
const std::set<Key>& marginalizableKeys,
|
||||||
|
CachedSearch* cache) {
|
||||||
|
auto it = cache->wholeMarginalizableCliques.find(clique.get());
|
||||||
|
if (it != cache->wholeMarginalizableCliques.end()) {
|
||||||
|
return it->second;
|
||||||
|
} else {
|
||||||
|
bool ret = true;
|
||||||
|
for (Key key : clique->conditional()->frontals()) {
|
||||||
|
if (!marginalizableKeys.count(key)) {
|
||||||
|
ret = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cache->wholeMarginalizableCliques.insert({clique.get(), ret});
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if all variables in the subtree are marginalizable.
|
||||||
|
*
|
||||||
|
* Note we use a cache map to avoid repeated searches.
|
||||||
|
*/
|
||||||
|
static bool isWholeSubtreeMarginalizable(
|
||||||
|
const sharedClique& subtree,
|
||||||
|
const std::set<Key>& marginalizableKeys,
|
||||||
|
CachedSearch* cache) {
|
||||||
|
auto it = cache->wholeMarginalizableSubtrees.find(subtree.get());
|
||||||
|
if (it != cache->wholeMarginalizableSubtrees.end()) {
|
||||||
|
return it->second;
|
||||||
|
} else {
|
||||||
|
bool ret = true;
|
||||||
|
if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) {
|
||||||
|
for (const sharedClique& child : subtree->children) {
|
||||||
|
if (!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) {
|
||||||
|
ret = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ret = false;
|
||||||
|
}
|
||||||
|
cache->wholeMarginalizableSubtrees.insert({subtree.get(), ret});
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a clique contains variables that need reelimination due to
|
||||||
|
* elimination ordering conflicts.
|
||||||
|
*
|
||||||
|
* @param[in] clique The clique to check
|
||||||
|
* @param[in] marginalizableKeys Set of keys to be marginalized
|
||||||
|
* @return true if any variables in the clique need re-elimination
|
||||||
|
*/
|
||||||
|
static bool needsReelimination(
|
||||||
|
const sharedClique& clique,
|
||||||
|
const std::set<Key>& marginalizableKeys,
|
||||||
|
CachedSearch* cache) {
|
||||||
|
bool hasNonMarginalizableAhead = false;
|
||||||
|
|
||||||
|
// Check each frontal variable in order
|
||||||
|
for (Key key : clique->conditional()->frontals()) {
|
||||||
|
if (marginalizableKeys.count(key)) {
|
||||||
|
// If we've seen non-marginalizable variables before this one,
|
||||||
|
// we need to reeliminate
|
||||||
|
if (hasNonMarginalizableAhead) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if any child depends on this marginalizable key and the
|
||||||
|
// subtree rooted at that child contains non-marginalizables.
|
||||||
|
for (const sharedClique& child : clique->children) {
|
||||||
|
if (hasDependency(child, key) &&
|
||||||
|
!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
hasNonMarginalizableAhead = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gather all subtrees that depend on a marginalizable key and contain
|
||||||
|
* non-marginalizable variables in their root.
|
||||||
|
*
|
||||||
|
* @param[in] rootClique The starting clique
|
||||||
|
* @param[in] marginalizableKeys Set of keys to be marginalized
|
||||||
|
* @param[out] dependentSubtrees Pointer to set storing dependent cliques
|
||||||
|
*/
|
||||||
|
static void gatherDependentSubtrees(
|
||||||
|
const sharedClique& rootClique,
|
||||||
|
const std::set<Key>& marginalizableKeys,
|
||||||
|
std::set<sharedClique>* dependentSubtrees,
|
||||||
|
CachedSearch* cache) {
|
||||||
|
for (Key key : rootClique->conditional()->frontals()) {
|
||||||
|
if (marginalizableKeys.count(key)) {
|
||||||
|
// Find children that depend on this key
|
||||||
|
for (const sharedClique& child : rootClique->children) {
|
||||||
|
if (!dependentSubtrees->count(child) &&
|
||||||
|
hasDependency(child, key)) {
|
||||||
|
getSubtreesContainingNonMarginalizables(
|
||||||
|
child, marginalizableKeys, cache, dependentSubtrees);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gather all subtrees that contain non-marginalizable variables in its root.
|
||||||
|
*/
|
||||||
|
static void getSubtreesContainingNonMarginalizables(
|
||||||
|
const sharedClique& rootClique,
|
||||||
|
const std::set<Key>& marginalizableKeys,
|
||||||
|
CachedSearch* cache,
|
||||||
|
std::set<sharedClique>* subtreesContainingNonMarginalizables) {
|
||||||
|
// If the root clique itself contains non-marginalizable variables, we
|
||||||
|
// just add it to subtreesContainingNonMarginalizables;
|
||||||
|
if (!isWholeCliqueMarginalizable(rootClique, marginalizableKeys, cache)) {
|
||||||
|
subtreesContainingNonMarginalizables->insert(rootClique);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we need to recursively check the children
|
||||||
|
for (const sharedClique& child : rootClique->children) {
|
||||||
|
getSubtreesContainingNonMarginalizables(
|
||||||
|
child, marginalizableKeys, cache,
|
||||||
|
subtreesContainingNonMarginalizables);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add all frontal variables from a clique to a key set.
|
||||||
|
*
|
||||||
|
* @param[in] clique Clique to add keys from
|
||||||
|
* @param[out] additionalKeys Pointer to the output key set
|
||||||
|
*/
|
||||||
|
static void addCliqueToKeySet(
|
||||||
|
const sharedClique& clique,
|
||||||
|
std::set<Key>* additionalKeys) {
|
||||||
|
for (Key key : clique->conditional()->frontals()) {
|
||||||
|
additionalKeys->insert(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add all frontal variables from a subtree to a key set.
|
||||||
|
*
|
||||||
|
* @param[in] subRoot Root clique of the subtree
|
||||||
|
* @param[out] additionalKeys Pointer to the output key set
|
||||||
|
*/
|
||||||
|
static void addSubtreeToKeySet(
|
||||||
|
const sharedClique& subRoot,
|
||||||
|
std::set<Key>* additionalKeys) {
|
||||||
|
std::set<sharedClique> cliques;
|
||||||
|
cliques.insert(subRoot);
|
||||||
|
while(!cliques.empty()) {
|
||||||
|
auto begin = cliques.begin();
|
||||||
|
sharedClique clique = *begin;
|
||||||
|
cliques.erase(begin);
|
||||||
|
addCliqueToKeySet(clique, additionalKeys);
|
||||||
|
for (const sharedClique& child : clique->children) {
|
||||||
|
cliques.insert(child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if the clique depends on the given key.
|
||||||
|
*
|
||||||
|
* @param[in] clique Clique to check
|
||||||
|
* @param[in] key Key to check for dependencies
|
||||||
|
* @return true if clique depends on the key
|
||||||
|
*/
|
||||||
|
static bool hasDependency(
|
||||||
|
const sharedClique& clique, Key key) {
|
||||||
|
auto conditional = clique->conditional();
|
||||||
|
if (std::find(conditional->beginParents(),
|
||||||
|
conditional->endParents(), key)
|
||||||
|
!= conditional->endParents()) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// BayesTreeMarginalizationHelper
|
// BayesTreeMarginalizationHelper
|
||||||
|
|
|
@ -120,8 +120,8 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark additional keys between the marginalized keys and the leaves
|
// Mark additional keys between the marginalized keys and the leaves
|
||||||
std::set<Key> additionalKeys;
|
|
||||||
#ifdef GTSAM_OLD_MARGINALIZATION
|
#ifdef GTSAM_OLD_MARGINALIZATION
|
||||||
|
std::set<Key> additionalKeys;
|
||||||
for(Key key: marginalizableKeys) {
|
for(Key key: marginalizableKeys) {
|
||||||
ISAM2Clique::shared_ptr clique = isam_[key];
|
ISAM2Clique::shared_ptr clique = isam_[key];
|
||||||
for(const ISAM2Clique::shared_ptr& child: clique->children) {
|
for(const ISAM2Clique::shared_ptr& child: clique->children) {
|
||||||
|
@ -129,8 +129,9 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
BayesTreeMarginalizationHelper<ISAM2>::gatherAdditionalKeysToReEliminate(
|
std::set<Key> additionalKeys =
|
||||||
isam_, marginalizableKeys, additionalKeys);
|
BayesTreeMarginalizationHelper<ISAM2>::gatherAdditionalKeysToReEliminate(
|
||||||
|
isam_, marginalizableKeys);
|
||||||
#endif
|
#endif
|
||||||
KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end());
|
KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end());
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue