diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 703684c78..cf00a2209 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -25,7 +25,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -241,18 +243,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors, /* ************************************************************************ */ /** * @brief Take negative log-values, shift them so that the minimum value is 0, - * and then exponentiate to create a DecisionTreeFactor (not normalized yet!). + * and then exponentiate to create a TableFactor (not normalized yet!). * * @param errors DecisionTree of (unnormalized) errors. - * @return DecisionTreeFactor::shared_ptr + * @return TableFactor::shared_ptr */ -static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors( +static TableFactor::shared_ptr DiscreteFactorFromErrors( const DiscreteKeys &discreteKeys, const AlgebraicDecisionTree &errors) { double min_log = errors.min(); AlgebraicDecisionTree potentials( errors, [&min_log](const double x) { return exp(-(x - min_log)); }); - return std::make_shared(discreteKeys, potentials); + return std::make_shared(discreteKeys, potentials); } /* ************************************************************************ */ @@ -285,12 +287,17 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(ConvertConditionalToTableFactor); #endif - // Convert DiscreteConditional to TableFactor - auto tdc = std::make_shared(*dc); + if (auto dtc = std::dynamic_pointer_cast(dc)) { + /// Get the underlying TableFactor + dfg.push_back(dtc->table()); + } else { + // Convert DiscreteConditional to TableFactor + auto tdc = std::make_shared(*dc); + dfg.push_back(tdc); + } #if GTSAM_HYBRID_TIMING gttoc_(ConvertConditionalToTableFactor); #endif - dfg.push_back(tdc); } else { throwRuntimeError("discreteElimination", f); }