From dfec8409feb891b522d27c32a414fb4e5bdafc77 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 00:27:04 -0500 Subject: [PATCH] use TableFactor for discrete elimination --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 703684c78..cf00a2209 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -25,7 +25,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -241,18 +243,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors, /* ************************************************************************ */ /** * @brief Take negative log-values, shift them so that the minimum value is 0, - * and then exponentiate to create a DecisionTreeFactor (not normalized yet!). + * and then exponentiate to create a TableFactor (not normalized yet!). * * @param errors DecisionTree of (unnormalized) errors. - * @return DecisionTreeFactor::shared_ptr + * @return TableFactor::shared_ptr */ -static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors( +static TableFactor::shared_ptr DiscreteFactorFromErrors( const DiscreteKeys &discreteKeys, const AlgebraicDecisionTree &errors) { double min_log = errors.min(); AlgebraicDecisionTree potentials( errors, [&min_log](const double x) { return exp(-(x - min_log)); }); - return std::make_shared(discreteKeys, potentials); + return std::make_shared(discreteKeys, potentials); } /* ************************************************************************ */ @@ -285,12 +287,17 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(ConvertConditionalToTableFactor); #endif - // Convert DiscreteConditional to TableFactor - auto tdc = std::make_shared(*dc); + if (auto dtc = std::dynamic_pointer_cast(dc)) { + /// Get the underlying TableFactor + dfg.push_back(dtc->table()); + } else { + // Convert DiscreteConditional to TableFactor + auto tdc = std::make_shared(*dc); + dfg.push_back(tdc); + } #if GTSAM_HYBRID_TIMING gttoc_(ConvertConditionalToTableFactor); #endif - dfg.push_back(tdc); } else { throwRuntimeError("discreteElimination", f); }