diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 2037dd951..2b63be660 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -121,25 +121,13 @@ namespace gtsam { static DecisionTreeFactor ProductAndNormalize( const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteProduct); -#endif DecisionTreeFactor product = factors.product(); -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteProduct); -#endif // Max over all the potentials by pretending all keys are frontal: auto normalizer = product.max(product.size()); -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteNormalize); -#endif // Normalize the product factor to prevent underflow. product = product / (*normalizer); -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteNormalize); -#endif return product; } @@ -230,13 +218,7 @@ namespace gtsam { DecisionTreeFactor product = ProductAndNormalize(factors); // sum out frontals, this is the factor on the separator -#if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteSum); -#endif DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteSum); -#endif // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; @@ -246,14 +228,8 @@ namespace gtsam { sum->keys().end()); // now divide product/sum to get conditional -#if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteToDiscreteConditional); -#endif auto conditional = std::make_shared(product, *sum, orderedKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteToDiscreteConditional); -#endif return {conditional, sum}; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index cf00a2209..1ad0cdaf4 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -20,7 +20,6 @@ #include #include -#include #include #include #include @@ -257,6 +256,48 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors( return std::make_shared(discreteKeys, potentials); } +/** + * @brief Multiply all the `factors` and normalize the + * product to prevent underflow. + * + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return TableFactor + */ +static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) { + // PRODUCT: multiply all factors +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteProduct); +#endif + TableFactor product; + for (auto &&factor : factors) { + if (factor) { + if (auto f = std::dynamic_pointer_cast(factor)) { + product = product * (*f); + } else if (auto dtf = + std::dynamic_pointer_cast(factor)) { + product = TableFactor(product * (*dtf)); + } + } + } +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteProduct); +#endif + + // Max over all the potentials by pretending all keys are frontal: + auto normalizer = product.max(product.size()); + +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteNormalize); +#endif + // Normalize the product factor to prevent underflow. + product = product / (*normalizer); +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteNormalize); +#endif + + return product; +} + /* ************************************************************************ */ static std::pair> discreteElimination(const HybridGaussianFactorGraph &factors, @@ -306,13 +347,37 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscrete); #endif - // NOTE: This does sum-product. For max-product, use EliminateForMPE. - auto result = EliminateDiscrete(dfg, frontalKeys); + /**** NOTE: This does sum-product. ****/ + // Get product factor + TableFactor product = ProductAndNormalize(dfg); + +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteSum); +#endif + // All the discrete variables should form a single clique, + // so we can sum out on all the variables as frontals. + // This should give an empty separator. + Ordering orderedKeys(product.keys()); + TableFactor::shared_ptr sum = product.sum(orderedKeys); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscreteSum); +#endif + +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteFormDiscreteConditional); +#endif + // Finally, get the conditional + auto conditional = + std::make_shared(product, *sum, orderedKeys); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscreteFormDiscreteConditional); +#endif + #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscrete); #endif - return {std::make_shared(result.first), result.second}; + return {std::make_shared(conditional), sum}; } /* ************************************************************************ */ diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 31e36101b..1942e9234 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -114,10 +114,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) { EXPECT(HybridConditional::CheckInvariants(*result.first, values)); // Check that factor is discrete and correct - auto factor = std::dynamic_pointer_cast(result.second); + auto factor = std::dynamic_pointer_cast(result.second); CHECK(factor); // regression test - EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5)); + EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5)); } /* ************************************************************************* */ @@ -329,7 +329,7 @@ TEST(HybridBayesNet, Switching) { // Check the remaining factor for x1 CHECK(factor_x1); - auto phi_x1 = std::dynamic_pointer_cast(factor_x1); + auto phi_x1 = std::dynamic_pointer_cast(factor_x1); CHECK(phi_x1); EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0 // We can't really check the error of the decision tree factor phi_x1, because