diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 0df46f262..31d256d6f 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -45,9 +45,16 @@ bool HybridBayesTree::equals(const This& other, double tol) const { /* ************************************************************************* */ DiscreteValues HybridBayesTree::discreteMaxProduct( const DiscreteFactorGraph& dfg) const { - TableFactor product = TableProduct(dfg); + DiscreteFactor::shared_ptr product = dfg.scaledProduct(); - DiscreteValues assignment = TableDistribution(product).argmax(); + // Check type of product, and get as TableFactor for efficiency. + TableFactor p; + if (auto tf = std::dynamic_pointer_cast(product)) { + p = *tf; + } else { + p = TableFactor(product->toDecisionTreeFactor()); + } + DiscreteValues assignment = TableDistribution(p).argmax(); return assignment; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index d7813f1e5..581d027c8 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -255,43 +255,6 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors( return std::make_shared(discreteKeys, potentials); } -/* ************************************************************************ */ -TableFactor TableProduct(const DiscreteFactorGraph &factors) { - // PRODUCT: multiply all factors -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteProduct); -#endif - TableFactor product; - for (auto &&factor : factors) { - if (factor) { - if (auto dtc = std::dynamic_pointer_cast(factor)) { - product = product * dtc->table(); - } else if (auto f = std::dynamic_pointer_cast(factor)) { - product = product * (*f); - } else if (auto dtf = - std::dynamic_pointer_cast(factor)) { - product = product * TableFactor(*dtf); - } - } - } -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteProduct); -#endif - -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteNormalize); -#endif - // Max over all the potentials by pretending all keys are frontal: - auto denominator = product.max(product.size()); - // Normalize the product factor to prevent underflow. - product = product / *std::dynamic_pointer_cast(denominator); -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteNormalize); -#endif - - return product; -} - /* ************************************************************************ */ static DiscreteFactorGraph CollectDiscreteFactors( const HybridGaussianFactorGraph &factors) { @@ -357,17 +320,24 @@ discreteElimination(const HybridGaussianFactorGraph &factors, // so we can use the TableFactor for efficiency. if (frontalKeys.size() == dfg.keys().size()) { // Get product factor - TableFactor product = TableProduct(dfg); + DiscreteFactor::shared_ptr product = dfg.scaledProduct(); #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); #endif - auto conditional = std::make_shared(product); + // Check type of product, and get as TableFactor for efficiency. + TableFactor p; + if (auto tf = std::dynamic_pointer_cast(product)) { + p = *tf; + } else { + p = TableFactor(product->toDecisionTreeFactor()); + } + auto conditional = std::make_shared(p); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif - DiscreteFactor::shared_ptr sum = product.sum(frontalKeys); + DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); return {std::make_shared(conditional), sum}; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index b7d815ec6..832ab56a6 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -271,13 +271,4 @@ template <> struct traits : public Testable {}; -/** - * @brief Multiply all the `factors` and normalize the - * product to prevent underflow. - * - * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return TableFactor - */ -TableFactor TableProduct(const DiscreteFactorGraph& factors); - } // namespace gtsam