Some refinement in BayesTreeMarginalizationHelper:
1. Skip subtrees that have already been visited when searching for dependent cliques; 2. Avoid copying shared_ptrs (which needs extra expensive atomic operations) in the searching. Use const Clique* instead of sharedClique whenever possible; 3. Use std::unordered_set instead of std::set to improve average searching speed.release/4.3a0
parent
0d9c3a9958
commit
06dac43cae
|
@ -21,6 +21,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
#include <deque>
|
#include <deque>
|
||||||
#include <gtsam/inference/BayesTree.h>
|
#include <gtsam/inference/BayesTree.h>
|
||||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||||
|
@ -62,30 +63,18 @@ public:
|
||||||
* @param[in] marginalizableKeys Keys to be marginalized
|
* @param[in] marginalizableKeys Keys to be marginalized
|
||||||
* @return Set of additional keys that need to be re-eliminated
|
* @return Set of additional keys that need to be re-eliminated
|
||||||
*/
|
*/
|
||||||
static std::set<Key> gatherAdditionalKeysToReEliminate(
|
static std::unordered_set<Key>
|
||||||
|
gatherAdditionalKeysToReEliminate(
|
||||||
const BayesTree& bayesTree,
|
const BayesTree& bayesTree,
|
||||||
const KeyVector& marginalizableKeys) {
|
const KeyVector& marginalizableKeys) {
|
||||||
const bool debug = ISDEBUG("BayesTreeMarginalizationHelper");
|
const bool debug = ISDEBUG("BayesTreeMarginalizationHelper");
|
||||||
|
|
||||||
std::set<Key> additionalKeys;
|
std::unordered_set<const Clique*> additionalCliques =
|
||||||
std::set<Key> marginalizableKeySet(
|
gatherAdditionalCliquesToReEliminate(bayesTree, marginalizableKeys);
|
||||||
marginalizableKeys.begin(), marginalizableKeys.end());
|
|
||||||
CachedSearch cachedSearch;
|
|
||||||
|
|
||||||
// Check each clique that contains a marginalizable key
|
std::unordered_set<Key> additionalKeys;
|
||||||
for (const sharedClique& clique :
|
for (const Clique* clique : additionalCliques) {
|
||||||
getCliquesContainingKeys(bayesTree, marginalizableKeySet)) {
|
addCliqueToKeySet(clique, &additionalKeys);
|
||||||
|
|
||||||
if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) {
|
|
||||||
// Add frontal variables from current clique
|
|
||||||
addCliqueToKeySet(clique, &additionalKeys);
|
|
||||||
|
|
||||||
// Then add the dependent cliques
|
|
||||||
for (const sharedClique& dependent :
|
|
||||||
gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) {
|
|
||||||
addCliqueToKeySet(dependent, &additionalKeys);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (debug) {
|
if (debug) {
|
||||||
|
@ -100,6 +89,41 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
/**
|
||||||
|
* This function identifies cliques that need to be re-eliminated before
|
||||||
|
* performing marginalization.
|
||||||
|
* See the docstring of @ref gatherAdditionalKeysToReEliminate().
|
||||||
|
*/
|
||||||
|
static std::unordered_set<const Clique*>
|
||||||
|
gatherAdditionalCliquesToReEliminate(
|
||||||
|
const BayesTree& bayesTree,
|
||||||
|
const KeyVector& marginalizableKeys) {
|
||||||
|
std::unordered_set<const Clique*> additionalCliques;
|
||||||
|
std::unordered_set<Key> marginalizableKeySet(
|
||||||
|
marginalizableKeys.begin(), marginalizableKeys.end());
|
||||||
|
CachedSearch cachedSearch;
|
||||||
|
|
||||||
|
// Check each clique that contains a marginalizable key
|
||||||
|
for (const Clique* clique :
|
||||||
|
getCliquesContainingKeys(bayesTree, marginalizableKeySet)) {
|
||||||
|
if (additionalCliques.count(clique)) {
|
||||||
|
// The clique has already been visited. This can happen when an
|
||||||
|
// ancestor of the current clique also contain some marginalizable
|
||||||
|
// varaibles and it's processed beore the current.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) {
|
||||||
|
// Add the current clique
|
||||||
|
additionalCliques.insert(clique);
|
||||||
|
|
||||||
|
// Then add the dependent cliques
|
||||||
|
gatherDependentCliques(clique, marginalizableKeySet, &additionalCliques,
|
||||||
|
&cachedSearch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return additionalCliques;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gather the cliques containing any of the given keys.
|
* Gather the cliques containing any of the given keys.
|
||||||
|
@ -108,12 +132,12 @@ public:
|
||||||
* @param[in] keysOfInterest Set of keys of interest
|
* @param[in] keysOfInterest Set of keys of interest
|
||||||
* @return Set of cliques that contain any of the given keys
|
* @return Set of cliques that contain any of the given keys
|
||||||
*/
|
*/
|
||||||
static std::set<sharedClique> getCliquesContainingKeys(
|
static std::unordered_set<const Clique*> getCliquesContainingKeys(
|
||||||
const BayesTree& bayesTree,
|
const BayesTree& bayesTree,
|
||||||
const std::set<Key>& keysOfInterest) {
|
const std::unordered_set<Key>& keysOfInterest) {
|
||||||
std::set<sharedClique> cliques;
|
std::unordered_set<const Clique*> cliques;
|
||||||
for (const Key& key : keysOfInterest) {
|
for (const Key& key : keysOfInterest) {
|
||||||
cliques.insert(bayesTree[key]);
|
cliques.insert(bayesTree[key].get());
|
||||||
}
|
}
|
||||||
return cliques;
|
return cliques;
|
||||||
}
|
}
|
||||||
|
@ -122,8 +146,8 @@ public:
|
||||||
* A struct to cache the results of the below two functions.
|
* A struct to cache the results of the below two functions.
|
||||||
*/
|
*/
|
||||||
struct CachedSearch {
|
struct CachedSearch {
|
||||||
std::unordered_map<Clique*, bool> wholeMarginalizableCliques;
|
std::unordered_map<const Clique*, bool> wholeMarginalizableCliques;
|
||||||
std::unordered_map<Clique*, bool> wholeMarginalizableSubtrees;
|
std::unordered_map<const Clique*, bool> wholeMarginalizableSubtrees;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -132,10 +156,10 @@ public:
|
||||||
* Note we use a cache map to avoid repeated searches.
|
* Note we use a cache map to avoid repeated searches.
|
||||||
*/
|
*/
|
||||||
static bool isWholeCliqueMarginalizable(
|
static bool isWholeCliqueMarginalizable(
|
||||||
const sharedClique& clique,
|
const Clique* clique,
|
||||||
const std::set<Key>& marginalizableKeys,
|
const std::unordered_set<Key>& marginalizableKeys,
|
||||||
CachedSearch* cache) {
|
CachedSearch* cache) {
|
||||||
auto it = cache->wholeMarginalizableCliques.find(clique.get());
|
auto it = cache->wholeMarginalizableCliques.find(clique);
|
||||||
if (it != cache->wholeMarginalizableCliques.end()) {
|
if (it != cache->wholeMarginalizableCliques.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
} else {
|
} else {
|
||||||
|
@ -146,7 +170,7 @@ public:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cache->wholeMarginalizableCliques.insert({clique.get(), ret});
|
cache->wholeMarginalizableCliques.insert({clique, ret});
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -157,17 +181,17 @@ public:
|
||||||
* Note we use a cache map to avoid repeated searches.
|
* Note we use a cache map to avoid repeated searches.
|
||||||
*/
|
*/
|
||||||
static bool isWholeSubtreeMarginalizable(
|
static bool isWholeSubtreeMarginalizable(
|
||||||
const sharedClique& subtree,
|
const Clique* subtree,
|
||||||
const std::set<Key>& marginalizableKeys,
|
const std::unordered_set<Key>& marginalizableKeys,
|
||||||
CachedSearch* cache) {
|
CachedSearch* cache) {
|
||||||
auto it = cache->wholeMarginalizableSubtrees.find(subtree.get());
|
auto it = cache->wholeMarginalizableSubtrees.find(subtree);
|
||||||
if (it != cache->wholeMarginalizableSubtrees.end()) {
|
if (it != cache->wholeMarginalizableSubtrees.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
} else {
|
} else {
|
||||||
bool ret = true;
|
bool ret = true;
|
||||||
if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) {
|
if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) {
|
||||||
for (const sharedClique& child : subtree->children) {
|
for (const sharedClique& child : subtree->children) {
|
||||||
if (!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) {
|
if (!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) {
|
||||||
ret = false;
|
ret = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -175,7 +199,7 @@ public:
|
||||||
} else {
|
} else {
|
||||||
ret = false;
|
ret = false;
|
||||||
}
|
}
|
||||||
cache->wholeMarginalizableSubtrees.insert({subtree.get(), ret});
|
cache->wholeMarginalizableSubtrees.insert({subtree, ret});
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -189,8 +213,8 @@ public:
|
||||||
* @return true if any variables in the clique need re-elimination
|
* @return true if any variables in the clique need re-elimination
|
||||||
*/
|
*/
|
||||||
static bool needsReelimination(
|
static bool needsReelimination(
|
||||||
const sharedClique& clique,
|
const Clique* clique,
|
||||||
const std::set<Key>& marginalizableKeys,
|
const std::unordered_set<Key>& marginalizableKeys,
|
||||||
CachedSearch* cache) {
|
CachedSearch* cache) {
|
||||||
bool hasNonMarginalizableAhead = false;
|
bool hasNonMarginalizableAhead = false;
|
||||||
|
|
||||||
|
@ -206,8 +230,8 @@ public:
|
||||||
// Check if any child depends on this marginalizable key and the
|
// Check if any child depends on this marginalizable key and the
|
||||||
// subtree rooted at that child contains non-marginalizables.
|
// subtree rooted at that child contains non-marginalizables.
|
||||||
for (const sharedClique& child : clique->children) {
|
for (const sharedClique& child : clique->children) {
|
||||||
if (hasDependency(child, key) &&
|
if (hasDependency(child.get(), key) &&
|
||||||
!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) {
|
!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -225,47 +249,59 @@ public:
|
||||||
* @param[in] rootClique The root clique
|
* @param[in] rootClique The root clique
|
||||||
* @param[in] marginalizableKeys Set of keys to be marginalized
|
* @param[in] marginalizableKeys Set of keys to be marginalized
|
||||||
*/
|
*/
|
||||||
static std::set<sharedClique> gatherDependentCliques(
|
static void gatherDependentCliques(
|
||||||
const sharedClique& rootClique,
|
const Clique* rootClique,
|
||||||
const std::set<Key>& marginalizableKeys,
|
const std::unordered_set<Key>& marginalizableKeys,
|
||||||
|
std::unordered_set<const Clique*>* additionalCliques,
|
||||||
CachedSearch* cache) {
|
CachedSearch* cache) {
|
||||||
std::vector<sharedClique> dependentChildren;
|
std::vector<const Clique*> dependentChildren;
|
||||||
dependentChildren.reserve(rootClique->children.size());
|
dependentChildren.reserve(rootClique->children.size());
|
||||||
for (const sharedClique& child : rootClique->children) {
|
for (const sharedClique& child : rootClique->children) {
|
||||||
if (hasDependency(child, marginalizableKeys)) {
|
if (additionalCliques->count(child.get())) {
|
||||||
dependentChildren.push_back(child);
|
// This child has already been visited. This can happen if the
|
||||||
|
// child itself contains a marginalizable variable and it's
|
||||||
|
// processed before the current rootClique.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (hasDependency(child.get(), marginalizableKeys)) {
|
||||||
|
dependentChildren.push_back(child.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache);
|
gatherDependentCliquesFromChildren(
|
||||||
|
dependentChildren, marginalizableKeys, additionalCliques, cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A helper function for the above gatherDependentCliques().
|
* A helper function for the above gatherDependentCliques().
|
||||||
*/
|
*/
|
||||||
static std::set<sharedClique> gatherDependentCliquesFromChildren(
|
static void gatherDependentCliquesFromChildren(
|
||||||
const std::vector<sharedClique>& dependentChildren,
|
const std::vector<const Clique*>& dependentChildren,
|
||||||
const std::set<Key>& marginalizableKeys,
|
const std::unordered_set<Key>& marginalizableKeys,
|
||||||
|
std::unordered_set<const Clique*>* additionalCliques,
|
||||||
CachedSearch* cache) {
|
CachedSearch* cache) {
|
||||||
std::deque<sharedClique> descendants(
|
std::deque<const Clique*> descendants(
|
||||||
dependentChildren.begin(), dependentChildren.end());
|
dependentChildren.begin(), dependentChildren.end());
|
||||||
std::set<sharedClique> dependentCliques;
|
|
||||||
while (!descendants.empty()) {
|
while (!descendants.empty()) {
|
||||||
sharedClique descendant = descendants.front();
|
const Clique* descendant = descendants.front();
|
||||||
descendants.pop_front();
|
descendants.pop_front();
|
||||||
|
|
||||||
// If the subtree rooted at this descendant contains non-marginalizables,
|
// If the subtree rooted at this descendant contains non-marginalizables,
|
||||||
// it must lie on a path from the root clique to a clique containing
|
// it must lie on a path from the root clique to a clique containing
|
||||||
// non-marginalizables at the leaf side.
|
// non-marginalizables at the leaf side.
|
||||||
if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) {
|
if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) {
|
||||||
dependentCliques.insert(descendant);
|
additionalCliques->insert(descendant);
|
||||||
}
|
|
||||||
|
|
||||||
// Add all children of the current descendant to the set descendants.
|
// Add children of the current descendant to the set descendants.
|
||||||
for (const sharedClique& child : descendant->children) {
|
for (const sharedClique& child : descendant->children) {
|
||||||
descendants.push_back(child);
|
if (additionalCliques->count(child.get())) {
|
||||||
|
// This child has already been visited.
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
descendants.push_back(child.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return dependentCliques;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -275,8 +311,8 @@ public:
|
||||||
* @param[out] additionalKeys Pointer to the output key set
|
* @param[out] additionalKeys Pointer to the output key set
|
||||||
*/
|
*/
|
||||||
static void addCliqueToKeySet(
|
static void addCliqueToKeySet(
|
||||||
const sharedClique& clique,
|
const Clique* clique,
|
||||||
std::set<Key>* additionalKeys) {
|
std::unordered_set<Key>* additionalKeys) {
|
||||||
for (Key key : clique->conditional()->frontals()) {
|
for (Key key : clique->conditional()->frontals()) {
|
||||||
additionalKeys->insert(key);
|
additionalKeys->insert(key);
|
||||||
}
|
}
|
||||||
|
@ -290,8 +326,8 @@ public:
|
||||||
* @return true if clique depends on the key
|
* @return true if clique depends on the key
|
||||||
*/
|
*/
|
||||||
static bool hasDependency(
|
static bool hasDependency(
|
||||||
const sharedClique& clique, Key key) {
|
const Clique* clique, Key key) {
|
||||||
auto conditional = clique->conditional();
|
auto& conditional = clique->conditional();
|
||||||
if (std::find(conditional->beginParents(),
|
if (std::find(conditional->beginParents(),
|
||||||
conditional->endParents(), key)
|
conditional->endParents(), key)
|
||||||
!= conditional->endParents()) {
|
!= conditional->endParents()) {
|
||||||
|
@ -305,12 +341,15 @@ public:
|
||||||
* Check if the clique depends on any of the given keys.
|
* Check if the clique depends on any of the given keys.
|
||||||
*/
|
*/
|
||||||
static bool hasDependency(
|
static bool hasDependency(
|
||||||
const sharedClique& clique, const std::set<Key>& keys) {
|
const Clique* clique, const std::unordered_set<Key>& keys) {
|
||||||
for (Key key : keys) {
|
auto& conditional = clique->conditional();
|
||||||
if (hasDependency(clique, key)) {
|
for (auto it = conditional->beginParents();
|
||||||
|
it != conditional->endParents(); ++it) {
|
||||||
|
if (keys.count(*it)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -93,7 +93,7 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update(
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::set<Key> additionalKeys =
|
std::unordered_set<Key> additionalKeys =
|
||||||
BayesTreeMarginalizationHelper<ISAM2>::gatherAdditionalKeysToReEliminate(
|
BayesTreeMarginalizationHelper<ISAM2>::gatherAdditionalKeysToReEliminate(
|
||||||
isam_, marginalizableKeys);
|
isam_, marginalizableKeys);
|
||||||
KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end());
|
KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end());
|
||||||
|
|
Loading…
Reference in New Issue