Some refinement in BayesTreeMarginalizationHelper:
1. Skip subtrees that have already been visited when searching for dependent cliques; 2. Avoid copying shared_ptrs (which needs extra expensive atomic operations) in the searching. Use const Clique* instead of sharedClique whenever possible; 3. Use std::unordered_set instead of std::set to improve average searching speed.release/4.3a0
							parent
							
								
									0d9c3a9958
								
							
						
					
					
						commit
						06dac43cae
					
				|  | @ -21,6 +21,7 @@ | |||
| #pragma once | ||||
| 
 | ||||
| #include <unordered_map> | ||||
| #include <unordered_set> | ||||
| #include <deque> | ||||
| #include <gtsam/inference/BayesTree.h> | ||||
| #include <gtsam/inference/BayesTreeCliqueBase.h> | ||||
|  | @ -62,30 +63,18 @@ public: | |||
|    * @param[in] marginalizableKeys Keys to be marginalized | ||||
|    * @return Set of additional keys that need to be re-eliminated | ||||
|    */ | ||||
|   static std::set<Key> gatherAdditionalKeysToReEliminate( | ||||
|   static std::unordered_set<Key> | ||||
|   gatherAdditionalKeysToReEliminate( | ||||
|       const BayesTree& bayesTree, | ||||
|       const KeyVector& marginalizableKeys) { | ||||
|     const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); | ||||
| 
 | ||||
|     std::set<Key> additionalKeys; | ||||
|     std::set<Key> marginalizableKeySet( | ||||
|         marginalizableKeys.begin(), marginalizableKeys.end()); | ||||
|     CachedSearch cachedSearch; | ||||
|     std::unordered_set<const Clique*> additionalCliques = | ||||
|         gatherAdditionalCliquesToReEliminate(bayesTree, marginalizableKeys);     | ||||
| 
 | ||||
|     // Check each clique that contains a marginalizable key
 | ||||
|     for (const sharedClique& clique : | ||||
|          getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { | ||||
| 
 | ||||
|       if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { | ||||
|         // Add frontal variables from current clique
 | ||||
|         addCliqueToKeySet(clique, &additionalKeys); | ||||
| 
 | ||||
|         // Then add the dependent cliques
 | ||||
|         for (const sharedClique& dependent : | ||||
|              gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) { | ||||
|           addCliqueToKeySet(dependent, &additionalKeys); | ||||
|         } | ||||
|       } | ||||
|     std::unordered_set<Key> additionalKeys; | ||||
|     for (const Clique* clique : additionalCliques) { | ||||
|       addCliqueToKeySet(clique, &additionalKeys); | ||||
|     } | ||||
| 
 | ||||
|     if (debug) { | ||||
|  | @ -100,6 +89,41 @@ public: | |||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   /**
 | ||||
|    * This function identifies cliques that need to be re-eliminated before | ||||
|    * performing marginalization. | ||||
|    * See the docstring of @ref gatherAdditionalKeysToReEliminate(). | ||||
|    */ | ||||
|   static std::unordered_set<const Clique*> | ||||
|   gatherAdditionalCliquesToReEliminate( | ||||
|       const BayesTree& bayesTree, | ||||
|       const KeyVector& marginalizableKeys) { | ||||
|     std::unordered_set<const Clique*> additionalCliques; | ||||
|     std::unordered_set<Key> marginalizableKeySet( | ||||
|         marginalizableKeys.begin(), marginalizableKeys.end()); | ||||
|     CachedSearch cachedSearch; | ||||
| 
 | ||||
|     // Check each clique that contains a marginalizable key
 | ||||
|     for (const Clique* clique : | ||||
|          getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { | ||||
|       if (additionalCliques.count(clique)) { | ||||
|         // The clique has already been visited. This can happen when an
 | ||||
|         // ancestor of the current clique also contain some marginalizable
 | ||||
|         // varaibles and it's processed beore the current.
 | ||||
|         continue; | ||||
|       } | ||||
| 
 | ||||
|       if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { | ||||
|         // Add the current clique
 | ||||
|         additionalCliques.insert(clique); | ||||
| 
 | ||||
|         // Then add the dependent cliques
 | ||||
|         gatherDependentCliques(clique, marginalizableKeySet, &additionalCliques, | ||||
|                                &cachedSearch); | ||||
|       } | ||||
|     } | ||||
|     return additionalCliques; | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Gather the cliques containing any of the given keys. | ||||
|  | @ -108,12 +132,12 @@ public: | |||
|    * @param[in] keysOfInterest Set of keys of interest | ||||
|    * @return Set of cliques that contain any of the given keys | ||||
|    */ | ||||
|   static std::set<sharedClique> getCliquesContainingKeys( | ||||
|   static std::unordered_set<const Clique*> getCliquesContainingKeys( | ||||
|       const BayesTree& bayesTree, | ||||
|       const std::set<Key>& keysOfInterest) { | ||||
|     std::set<sharedClique> cliques; | ||||
|       const std::unordered_set<Key>& keysOfInterest) { | ||||
|     std::unordered_set<const Clique*> cliques; | ||||
|     for (const Key& key : keysOfInterest) { | ||||
|       cliques.insert(bayesTree[key]); | ||||
|       cliques.insert(bayesTree[key].get()); | ||||
|     } | ||||
|     return cliques; | ||||
|   } | ||||
|  | @ -122,8 +146,8 @@ public: | |||
|    * A struct to cache the results of the below two functions. | ||||
|    */ | ||||
|   struct CachedSearch { | ||||
|     std::unordered_map<Clique*, bool> wholeMarginalizableCliques; | ||||
|     std::unordered_map<Clique*, bool> wholeMarginalizableSubtrees; | ||||
|     std::unordered_map<const Clique*, bool> wholeMarginalizableCliques; | ||||
|     std::unordered_map<const Clique*, bool> wholeMarginalizableSubtrees; | ||||
|   }; | ||||
| 
 | ||||
|   /**
 | ||||
|  | @ -132,10 +156,10 @@ public: | |||
|    * Note we use a cache map to avoid repeated searches. | ||||
|    */ | ||||
|   static bool isWholeCliqueMarginalizable( | ||||
|       const sharedClique& clique, | ||||
|       const std::set<Key>& marginalizableKeys, | ||||
|       const Clique* clique, | ||||
|       const std::unordered_set<Key>& marginalizableKeys, | ||||
|       CachedSearch* cache) { | ||||
|     auto it = cache->wholeMarginalizableCliques.find(clique.get()); | ||||
|     auto it = cache->wholeMarginalizableCliques.find(clique); | ||||
|     if (it != cache->wholeMarginalizableCliques.end()) { | ||||
|       return it->second; | ||||
|     } else { | ||||
|  | @ -146,7 +170,7 @@ public: | |||
|           break; | ||||
|         } | ||||
|       } | ||||
|       cache->wholeMarginalizableCliques.insert({clique.get(), ret}); | ||||
|       cache->wholeMarginalizableCliques.insert({clique, ret}); | ||||
|       return ret; | ||||
|     } | ||||
|   } | ||||
|  | @ -157,17 +181,17 @@ public: | |||
|    * Note we use a cache map to avoid repeated searches. | ||||
|    */ | ||||
|   static bool isWholeSubtreeMarginalizable( | ||||
|       const sharedClique& subtree, | ||||
|       const std::set<Key>& marginalizableKeys, | ||||
|       const Clique* subtree, | ||||
|       const std::unordered_set<Key>& marginalizableKeys, | ||||
|       CachedSearch* cache) { | ||||
|     auto it = cache->wholeMarginalizableSubtrees.find(subtree.get()); | ||||
|     auto it = cache->wholeMarginalizableSubtrees.find(subtree); | ||||
|     if (it != cache->wholeMarginalizableSubtrees.end()) { | ||||
|       return it->second; | ||||
|     } else { | ||||
|       bool ret = true; | ||||
|       if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) { | ||||
|         for (const sharedClique& child : subtree->children) { | ||||
|           if (!isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { | ||||
|           if (!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { | ||||
|             ret = false; | ||||
|             break; | ||||
|           } | ||||
|  | @ -175,7 +199,7 @@ public: | |||
|       } else { | ||||
|         ret = false; | ||||
|       } | ||||
|       cache->wholeMarginalizableSubtrees.insert({subtree.get(), ret}); | ||||
|       cache->wholeMarginalizableSubtrees.insert({subtree, ret}); | ||||
|       return ret; | ||||
|     } | ||||
|   } | ||||
|  | @ -189,8 +213,8 @@ public: | |||
|    * @return true if any variables in the clique need re-elimination | ||||
|    */ | ||||
|   static bool needsReelimination( | ||||
|       const sharedClique& clique, | ||||
|       const std::set<Key>& marginalizableKeys, | ||||
|       const Clique* clique, | ||||
|       const std::unordered_set<Key>& marginalizableKeys, | ||||
|       CachedSearch* cache) { | ||||
|     bool hasNonMarginalizableAhead = false; | ||||
| 
 | ||||
|  | @ -206,8 +230,8 @@ public: | |||
|         // Check if any child depends on this marginalizable key and the
 | ||||
|         // subtree rooted at that child contains non-marginalizables.
 | ||||
|         for (const sharedClique& child : clique->children) { | ||||
|           if (hasDependency(child, key) && | ||||
|               !isWholeSubtreeMarginalizable(child, marginalizableKeys, cache)) { | ||||
|           if (hasDependency(child.get(), key) && | ||||
|               !isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { | ||||
|             return true; | ||||
|           } | ||||
|         } | ||||
|  | @ -225,47 +249,59 @@ public: | |||
|    * @param[in] rootClique The root clique | ||||
|    * @param[in] marginalizableKeys Set of keys to be marginalized | ||||
|    */ | ||||
|   static std::set<sharedClique> gatherDependentCliques( | ||||
|       const sharedClique& rootClique, | ||||
|       const std::set<Key>& marginalizableKeys, | ||||
|   static void gatherDependentCliques( | ||||
|       const Clique* rootClique, | ||||
|       const std::unordered_set<Key>& marginalizableKeys, | ||||
|       std::unordered_set<const Clique*>* additionalCliques, | ||||
|       CachedSearch* cache) { | ||||
|     std::vector<sharedClique> dependentChildren; | ||||
|     std::vector<const Clique*> dependentChildren; | ||||
|     dependentChildren.reserve(rootClique->children.size()); | ||||
|     for (const sharedClique& child : rootClique->children) { | ||||
|       if (hasDependency(child, marginalizableKeys)) { | ||||
|         dependentChildren.push_back(child); | ||||
|       if (additionalCliques->count(child.get())) { | ||||
|         // This child has already been visited. This can happen if the
 | ||||
|         // child itself contains a marginalizable variable and it's
 | ||||
|         // processed before the current rootClique.
 | ||||
|         continue; | ||||
|       } | ||||
|       if (hasDependency(child.get(), marginalizableKeys)) { | ||||
|         dependentChildren.push_back(child.get()); | ||||
|       } | ||||
|     } | ||||
|     return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache); | ||||
|     gatherDependentCliquesFromChildren( | ||||
|         dependentChildren, marginalizableKeys, additionalCliques, cache); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * A helper function for the above gatherDependentCliques(). | ||||
|    */ | ||||
|   static std::set<sharedClique> gatherDependentCliquesFromChildren( | ||||
|       const std::vector<sharedClique>& dependentChildren, | ||||
|       const std::set<Key>& marginalizableKeys, | ||||
|   static void gatherDependentCliquesFromChildren( | ||||
|       const std::vector<const Clique*>& dependentChildren, | ||||
|       const std::unordered_set<Key>& marginalizableKeys, | ||||
|       std::unordered_set<const Clique*>* additionalCliques, | ||||
|       CachedSearch* cache) { | ||||
|     std::deque<sharedClique> descendants( | ||||
|     std::deque<const Clique*> descendants( | ||||
|         dependentChildren.begin(), dependentChildren.end()); | ||||
|     std::set<sharedClique> dependentCliques; | ||||
|     while (!descendants.empty()) { | ||||
|       sharedClique descendant = descendants.front(); | ||||
|       const Clique* descendant = descendants.front(); | ||||
|       descendants.pop_front(); | ||||
| 
 | ||||
|       // If the subtree rooted at this descendant contains non-marginalizables,
 | ||||
|       // it must lie on a path from the root clique to a clique containing
 | ||||
|       // non-marginalizables at the leaf side.
 | ||||
|       if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { | ||||
|         dependentCliques.insert(descendant); | ||||
|       } | ||||
|         additionalCliques->insert(descendant); | ||||
| 
 | ||||
|       // Add all children of the current descendant to the set descendants.
 | ||||
|       for (const sharedClique& child : descendant->children) { | ||||
|         descendants.push_back(child); | ||||
|         // Add children of the current descendant to the set descendants.
 | ||||
|         for (const sharedClique& child : descendant->children) { | ||||
|           if (additionalCliques->count(child.get())) { | ||||
|             // This child has already been visited.
 | ||||
|             continue; | ||||
|           } else { | ||||
|             descendants.push_back(child.get()); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     return dependentCliques; | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|  | @ -275,8 +311,8 @@ public: | |||
|    * @param[out] additionalKeys Pointer to the output key set | ||||
|    */ | ||||
|   static void addCliqueToKeySet( | ||||
|       const sharedClique& clique, | ||||
|       std::set<Key>* additionalKeys) { | ||||
|       const Clique* clique, | ||||
|       std::unordered_set<Key>* additionalKeys) { | ||||
|     for (Key key : clique->conditional()->frontals()) { | ||||
|       additionalKeys->insert(key); | ||||
|     } | ||||
|  | @ -290,8 +326,8 @@ public: | |||
|    * @return true if clique depends on the key | ||||
|    */ | ||||
|   static bool hasDependency( | ||||
|       const sharedClique& clique, Key key) { | ||||
|     auto conditional = clique->conditional(); | ||||
|       const Clique* clique, Key key) { | ||||
|     auto& conditional = clique->conditional(); | ||||
|     if (std::find(conditional->beginParents(), | ||||
|         conditional->endParents(), key) | ||||
|         != conditional->endParents()) { | ||||
|  | @ -305,12 +341,15 @@ public: | |||
|    * Check if the clique depends on any of the given keys. | ||||
|    */ | ||||
|   static bool hasDependency( | ||||
|       const sharedClique& clique, const std::set<Key>& keys) { | ||||
|     for (Key key : keys) { | ||||
|       if (hasDependency(clique, key)) { | ||||
|       const Clique* clique, const std::unordered_set<Key>& keys) { | ||||
|     auto& conditional = clique->conditional(); | ||||
|     for (auto it = conditional->beginParents(); | ||||
|         it != conditional->endParents(); ++it) { | ||||
|       if (keys.count(*it)) { | ||||
|         return true; | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     return false; | ||||
|   } | ||||
| }; | ||||
|  |  | |||
|  | @ -93,7 +93,7 @@ FixedLagSmoother::Result IncrementalFixedLagSmoother::update( | |||
|     std::cout << std::endl; | ||||
|   } | ||||
| 
 | ||||
|   std::set<Key> additionalKeys = | ||||
|   std::unordered_set<Key> additionalKeys = | ||||
|       BayesTreeMarginalizationHelper<ISAM2>::gatherAdditionalKeysToReEliminate( | ||||
|           isam_, marginalizableKeys); | ||||
|   KeyList additionalMarkedKeys(additionalKeys.begin(), additionalKeys.end()); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue