Further optimize the implementation of BayesTreeMarginalizationHelper:
Now we won't re-emilinate any unnecessary nodes (we re-emilinated whole subtrees in the previous commits, which is not optimal)release/4.3a0
							parent
							
								
									14c3467520
								
							
						
					
					
						commit
						1a5e711f0e
					
				|  | @ -21,6 +21,7 @@ | |||
| #pragma once | ||||
| 
 | ||||
| #include <unordered_map> | ||||
| #include <deque> | ||||
| #include <gtsam/inference/BayesTree.h> | ||||
| #include <gtsam/inference/BayesTreeCliqueBase.h> | ||||
| #include <gtsam/base/debug.h> | ||||
|  | @ -50,9 +51,12 @@ public: | |||
|    * 2. Or it has a child node depending on a marginalizable variable AND the | ||||
|    *    subtree rooted at that child contains non-marginalizables. | ||||
|    *  | ||||
|    * In addition, the subtrees under the aforementioned cliques that require | ||||
|    * re-elimination, which contain non-marginalizable variables in their root | ||||
|    * node, also need to be re-eliminated. | ||||
|    * In addition, for any descendant node depending on a marginalizable | ||||
|    * variable, if the subtree rooted at that descendant contains | ||||
|    * non-marginalizable variables (i.e., it lies on a path from one of the | ||||
|    * aforementioned cliques that require re-elimination to a node containing | ||||
|    * non-marginalizable variables at the leaf side), then it also needs to | ||||
|    * be re-eliminated. | ||||
|    *  | ||||
|    * @param[in] bayesTree The Bayes tree | ||||
|    * @param[in] marginalizableKeys Keys to be marginalized | ||||
|  | @ -66,7 +70,7 @@ public: | |||
|     std::set<Key> additionalKeys; | ||||
|     std::set<Key> marginalizableKeySet( | ||||
|         marginalizableKeys.begin(), marginalizableKeys.end()); | ||||
|     std::set<sharedClique> dependentSubtrees; | ||||
|     std::set<sharedClique> dependentCliques; | ||||
|     CachedSearch cachedSearch; | ||||
| 
 | ||||
|     // Check each clique that contains a marginalizable key
 | ||||
|  | @ -77,17 +81,14 @@ public: | |||
|         // Add frontal variables from current clique
 | ||||
|         addCliqueToKeySet(clique, &additionalKeys); | ||||
| 
 | ||||
|         // Then gather dependent subtrees to be added later
 | ||||
|         gatherDependentSubtrees( | ||||
|             clique, marginalizableKeySet, &dependentSubtrees, &cachedSearch); | ||||
|         // Then add the dependent cliques
 | ||||
|         for (const sharedClique& dependent : | ||||
|              gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) { | ||||
|           addCliqueToKeySet(dependent, &additionalKeys); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     // Add the remaining dependent cliques
 | ||||
|     for (const sharedClique& subtree : dependentSubtrees) { | ||||
|       addSubtreeToKeySet(subtree, &additionalKeys); | ||||
|     } | ||||
| 
 | ||||
|     if (debug) { | ||||
|       std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; | ||||
|       for (const Key& key : additionalKeys) { | ||||
|  | @ -219,53 +220,53 @@ public: | |||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Gather all subtrees that depend on a marginalizable key and contain | ||||
|    * non-marginalizable variables in their root. | ||||
|    * Gather all dependent nodes that lie on a path from the root clique | ||||
|    * to a clique containing a non-marginalizable variable at the leaf side. | ||||
|    * | ||||
|    * @param[in] rootClique The starting clique | ||||
|    * @param[in] rootClique The root clique | ||||
|    * @param[in] marginalizableKeys Set of keys to be marginalized | ||||
|    * @param[out] dependentSubtrees Pointer to set storing dependent cliques | ||||
|    */ | ||||
|   static void gatherDependentSubtrees( | ||||
|   static std::set<sharedClique> gatherDependentCliques( | ||||
|       const sharedClique& rootClique, | ||||
|       const std::set<Key>& marginalizableKeys, | ||||
|       std::set<sharedClique>* dependentSubtrees, | ||||
|       CachedSearch* cache) { | ||||
|     for (Key key : rootClique->conditional()->frontals()) { | ||||
|       if (marginalizableKeys.count(key)) { | ||||
|         // Find children that depend on this key
 | ||||
|         for (const sharedClique& child : rootClique->children) { | ||||
|           if (!dependentSubtrees->count(child) && | ||||
|               hasDependency(child, key)) { | ||||
|             getSubtreesContainingNonMarginalizables( | ||||
|                 child, marginalizableKeys, cache, dependentSubtrees); | ||||
|           } | ||||
|         } | ||||
|     std::vector<sharedClique> dependentChildren; | ||||
|     dependentChildren.reserve(rootClique->children.size()); | ||||
|     for (const sharedClique& child : rootClique->children) { | ||||
|       if (hasDependency(child, marginalizableKeys)) { | ||||
|         dependentChildren.push_back(child); | ||||
|       } | ||||
|     } | ||||
|     return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Gather all subtrees that contain non-marginalizable variables in its root. | ||||
|    * A helper function for the above gatherDependentCliques(). | ||||
|    */ | ||||
|   static void getSubtreesContainingNonMarginalizables( | ||||
|       const sharedClique& rootClique, | ||||
|   static std::set<sharedClique> gatherDependentCliquesFromChildren( | ||||
|       const std::vector<sharedClique>& dependentChildren, | ||||
|       const std::set<Key>& marginalizableKeys, | ||||
|       CachedSearch* cache, | ||||
|       std::set<sharedClique>* subtreesContainingNonMarginalizables) { | ||||
|     // If the root clique itself contains non-marginalizable variables, we
 | ||||
|     // just add it to subtreesContainingNonMarginalizables;    
 | ||||
|     if (!isWholeCliqueMarginalizable(rootClique, marginalizableKeys, cache)) { | ||||
|       subtreesContainingNonMarginalizables->insert(rootClique); | ||||
|       return; | ||||
|     } | ||||
|       CachedSearch* cache) { | ||||
|     std::deque<sharedClique> descendants( | ||||
|         dependentChildren.begin(), dependentChildren.end()); | ||||
|     std::set<sharedClique> dependentCliques; | ||||
|     while (!descendants.empty()) { | ||||
|       sharedClique descendant = descendants.front(); | ||||
|       descendants.pop_front(); | ||||
| 
 | ||||
|     // Otherwise, we need to recursively check the children
 | ||||
|     for (const sharedClique& child : rootClique->children) { | ||||
|       getSubtreesContainingNonMarginalizables( | ||||
|           child, marginalizableKeys, cache, | ||||
|           subtreesContainingNonMarginalizables); | ||||
|       // If the subtree rooted at this descendant contains non-marginalizables,
 | ||||
|       // it must lie on a path from the root clique to a clique containing
 | ||||
|       // non-marginalizables at the leaf side.
 | ||||
|       if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { | ||||
|         dependentCliques.insert(descendant); | ||||
|       } | ||||
| 
 | ||||
|       // Add all children of the current descendant to the set descendants.
 | ||||
|       for (const sharedClique& child : descendant->children) { | ||||
|         descendants.push_back(child); | ||||
|       } | ||||
|     } | ||||
|     return dependentCliques; | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|  | @ -282,28 +283,6 @@ public: | |||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Add all frontal variables from a subtree to a key set. | ||||
|    * | ||||
|    * @param[in] subRoot Root clique of the subtree | ||||
|    * @param[out] additionalKeys Pointer to the output key set | ||||
|    */ | ||||
|   static void addSubtreeToKeySet( | ||||
|       const sharedClique& subRoot, | ||||
|       std::set<Key>* additionalKeys) { | ||||
|     std::set<sharedClique> cliques; | ||||
|     cliques.insert(subRoot); | ||||
|     while(!cliques.empty()) { | ||||
|       auto begin = cliques.begin(); | ||||
|       sharedClique clique = *begin; | ||||
|       cliques.erase(begin); | ||||
|       addCliqueToKeySet(clique, additionalKeys); | ||||
|       for (const sharedClique& child : clique->children) { | ||||
|         cliques.insert(child); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Check if the clique depends on the given key. | ||||
|    *  | ||||
|  | @ -322,6 +301,19 @@ public: | |||
|       return false; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Check if the clique depends on any of the given keys. | ||||
|    */ | ||||
|   static bool hasDependency( | ||||
|       const sharedClique& clique, const std::set<Key>& keys) { | ||||
|     for (Key key : keys) { | ||||
|       if (hasDependency(clique, key)) { | ||||
|         return true; | ||||
|       } | ||||
|     } | ||||
|     return false; | ||||
|   } | ||||
| }; | ||||
| // BayesTreeMarginalizationHelper
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue