switch between TableFactor and DecisionTreeFactor
							parent
							
								
									90fd416e5d
								
							
						
					
					
						commit
						1bdaa5062c
					
				|  | @ -50,6 +50,8 @@ | |||
| #include <utility> | ||||
| #include <vector> | ||||
| 
 | ||||
| #define GTSAM_HYBRID_WITH_TABLEFACTOR 0 | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
 | ||||
|  | @ -253,7 +255,11 @@ static DiscreteFactor::shared_ptr DiscreteFactorFromErrors( | |||
|   double min_log = errors.min(); | ||||
|   AlgebraicDecisionTree<Key> potentials( | ||||
|       errors, [&min_log](const double x) { return exp(-(x - min_log)); }); | ||||
| #if GTSAM_HYBRID_WITH_TABLEFACTOR | ||||
|   return std::make_shared<TableFactor>(discreteKeys, potentials); | ||||
| #else | ||||
|   return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
|  | @ -290,9 +296,13 @@ static DiscreteFactorGraph CollectDiscreteFactors( | |||
|         /// Get the underlying TableFactor
 | ||||
|         dfg.push_back(dtc->table()); | ||||
|       } else { | ||||
| #if GTSAM_HYBRID_WITH_TABLEFACTOR | ||||
|         // Convert DiscreteConditional to TableFactor
 | ||||
|         auto tdc = std::make_shared<TableFactor>(*dc); | ||||
|         dfg.push_back(tdc); | ||||
| #else | ||||
|         dfg.push_back(dc); | ||||
| #endif | ||||
|       } | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|       gttoc_(ConvertConditionalToTableFactor); | ||||
|  | @ -309,11 +319,18 @@ static DiscreteFactorGraph CollectDiscreteFactors( | |||
| static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> | ||||
| discreteElimination(const HybridGaussianFactorGraph &factors, | ||||
|                     const Ordering &frontalKeys) { | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttic_(CollectDiscreteFactors); | ||||
| #endif | ||||
|   DiscreteFactorGraph dfg = CollectDiscreteFactors(factors); | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttoc_(CollectDiscreteFactors); | ||||
| #endif | ||||
| 
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttic_(EliminateDiscrete); | ||||
| #endif | ||||
| #if GTSAM_HYBRID_WITH_TABLEFACTOR | ||||
|   // Check if separator is empty.
 | ||||
|   // This is the same as checking if the number of frontal variables
 | ||||
|   // is the same as the number of variables in the DiscreteFactorGraph.
 | ||||
|  | @ -323,9 +340,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors, | |||
|     // Get product factor
 | ||||
|     DiscreteFactor::shared_ptr product = dfg.scaledProduct(); | ||||
| 
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttic_(EliminateDiscreteFormDiscreteConditional); | ||||
| #endif | ||||
|     // Check type of product, and get as TableFactor for efficiency.
 | ||||
|     // Use object instead of pointer since we need it
 | ||||
|     // for the TableDistribution constructor.
 | ||||
|  | @ -337,19 +351,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors, | |||
|     } | ||||
|     auto conditional = std::make_shared<TableDistribution>(p); | ||||
| 
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|     gttoc_(EliminateDiscreteFormDiscreteConditional); | ||||
| #endif | ||||
| 
 | ||||
|     DiscreteFactor::shared_ptr sum = p.sum(frontalKeys); | ||||
| 
 | ||||
|     return {std::make_shared<HybridConditional>(conditional), sum}; | ||||
| 
 | ||||
|   } else { | ||||
| #endif | ||||
|     // Perform sum-product.
 | ||||
|     auto result = EliminateDiscrete(dfg, frontalKeys); | ||||
|     return {std::make_shared<HybridConditional>(result.first), result.second}; | ||||
| #if GTSAM_HYBRID_WITH_TABLEFACTOR | ||||
|   } | ||||
| #endif | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttoc_(EliminateDiscrete); | ||||
| #endif | ||||
|  | @ -411,8 +424,14 @@ static std::shared_ptr<Factor> createHybridGaussianFactor( | |||
|       throw std::runtime_error("createHybridGaussianFactors has mixed NULLs"); | ||||
|     } | ||||
|   }; | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttic_(HybridCreateGaussianFactor); | ||||
| #endif | ||||
|   DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults, | ||||
|                                                         correct); | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttoc_(HybridCreateGaussianFactor); | ||||
| #endif | ||||
| 
 | ||||
|   return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors); | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue