diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index bc36ec94d..cfe5c4b59 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -50,6 +50,8 @@ #include #include +#define GTSAM_HYBRID_WITH_TABLEFACTOR 0 + namespace gtsam { /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: @@ -253,7 +255,11 @@ static DiscreteFactor::shared_ptr DiscreteFactorFromErrors( double min_log = errors.min(); AlgebraicDecisionTree potentials( errors, [&min_log](const double x) { return exp(-(x - min_log)); }); +#if GTSAM_HYBRID_WITH_TABLEFACTOR return std::make_shared(discreteKeys, potentials); +#else + return std::make_shared(discreteKeys, potentials); +#endif } /* ************************************************************************ */ @@ -290,9 +296,13 @@ static DiscreteFactorGraph CollectDiscreteFactors( /// Get the underlying TableFactor dfg.push_back(dtc->table()); } else { +#if GTSAM_HYBRID_WITH_TABLEFACTOR // Convert DiscreteConditional to TableFactor auto tdc = std::make_shared(*dc); dfg.push_back(tdc); +#else + dfg.push_back(dc); +#endif } #if GTSAM_HYBRID_TIMING gttoc_(ConvertConditionalToTableFactor); @@ -309,11 +319,18 @@ static DiscreteFactorGraph CollectDiscreteFactors( static std::pair> discreteElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys) { +#if GTSAM_HYBRID_TIMING + gttic_(CollectDiscreteFactors); +#endif DiscreteFactorGraph dfg = CollectDiscreteFactors(factors); +#if GTSAM_HYBRID_TIMING + gttoc_(CollectDiscreteFactors); +#endif #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscrete); #endif +#if GTSAM_HYBRID_WITH_TABLEFACTOR // Check if separator is empty. // This is the same as checking if the number of frontal variables // is the same as the number of variables in the DiscreteFactorGraph. @@ -323,9 +340,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors, // Get product factor DiscreteFactor::shared_ptr product = dfg.scaledProduct(); -#if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteFormDiscreteConditional); -#endif // Check type of product, and get as TableFactor for efficiency. // Use object instead of pointer since we need it // for the TableDistribution constructor. @@ -337,19 +351,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } auto conditional = std::make_shared(p); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteFormDiscreteConditional); -#endif - DiscreteFactor::shared_ptr sum = p.sum(frontalKeys); return {std::make_shared(conditional), sum}; } else { +#endif // Perform sum-product. auto result = EliminateDiscrete(dfg, frontalKeys); return {std::make_shared(result.first), result.second}; +#if GTSAM_HYBRID_WITH_TABLEFACTOR } +#endif #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscrete); #endif @@ -411,8 +424,14 @@ static std::shared_ptr createHybridGaussianFactor( throw std::runtime_error("createHybridGaussianFactors has mixed NULLs"); } }; +#if GTSAM_HYBRID_TIMING + gttic_(HybridCreateGaussianFactor); +#endif DecisionTree newFactors(eliminationResults, correct); +#if GTSAM_HYBRID_TIMING + gttoc_(HybridCreateGaussianFactor); +#endif return std::make_shared(discreteSeparator, newFactors); }