From f95ae52aff0385b947a878beb3c3eb6258ddc2dd Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 14:16:00 -0500 Subject: [PATCH 1/5] Use TableFactor everywhere in hybrid elimination --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 703684c78..4fcd420b1 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -20,12 +20,12 @@ #include #include -#include #include #include #include #include #include +#include #include #include #include @@ -241,18 +241,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors, /* ************************************************************************ */ /** * @brief Take negative log-values, shift them so that the minimum value is 0, - * and then exponentiate to create a DecisionTreeFactor (not normalized yet!). + * and then exponentiate to create a TableFactor (not normalized yet!). * * @param errors DecisionTree of (unnormalized) errors. - * @return DecisionTreeFactor::shared_ptr + * @return TableFactor::shared_ptr */ -static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors( +static TableFactor::shared_ptr DiscreteFactorFromErrors( const DiscreteKeys &discreteKeys, const AlgebraicDecisionTree &errors) { double min_log = errors.min(); AlgebraicDecisionTree potentials( errors, [&min_log](const double x) { return exp(-(x - min_log)); }); - return std::make_shared(discreteKeys, potentials); + return std::make_shared(discreteKeys, potentials); } /* ************************************************************************ */ From 71ea8c5d4c22289c328c5bdad054037a9793b679 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 14:55:56 -0500 Subject: [PATCH 2/5] fix tests --- gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp | 6 +++--- gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 31e36101b..1942e9234 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -114,10 +114,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) { EXPECT(HybridConditional::CheckInvariants(*result.first, values)); // Check that factor is discrete and correct - auto factor = std::dynamic_pointer_cast(result.second); + auto factor = std::dynamic_pointer_cast(result.second); CHECK(factor); // regression test - EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5)); + EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5)); } /* ************************************************************************* */ @@ -329,7 +329,7 @@ TEST(HybridBayesNet, Switching) { // Check the remaining factor for x1 CHECK(factor_x1); - auto phi_x1 = std::dynamic_pointer_cast(factor_x1); + auto phi_x1 = std::dynamic_pointer_cast(factor_x1); CHECK(phi_x1); EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0 // We can't really check the error of the decision tree factor phi_x1, because diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 07f70e95c..6e844dbcb 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -368,10 +368,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents()); // This is now a discreteFactor - auto discreteFactor = dynamic_pointer_cast(factorOnModes); + auto discreteFactor = dynamic_pointer_cast(factorOnModes); CHECK(discreteFactor); EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); - EXPECT(discreteFactor->root_->isLeaf() == false); } /**************************************************************************** From 42f8e54c2a31127251401e410ceea932c38127a2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 14:27:01 -0500 Subject: [PATCH 3/5] customize discrete elimination in Hybrid --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 70 +++++++++++++++++++++- 1 file changed, 68 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 4fcd420b1..8b0a2349f 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -255,6 +255,48 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors( return std::make_shared(discreteKeys, potentials); } +/** + * @brief Multiply all the `factors` and normalize the + * product to prevent underflow. + * + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return TableFactor + */ +static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) { + // PRODUCT: multiply all factors +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteProduct); +#endif + TableFactor product; + for (const sharedFactor &factor : factors) { + if (factor) { + if (auto f = std::dynamic_pointer_cast(factor)) { + product = product * (*f); + } else if (auto dtf = + std::dynamic_pointer_cast(factor)) { + product = TableFactor(product * (*dtf)); + } + } + } +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteProduct); +#endif + + // Max over all the potentials by pretending all keys are frontal: + auto normalizer = product.max(product.size()); + +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteNormalize); +#endif + // Normalize the product factor to prevent underflow. + product = product / (*normalizer); +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteNormalize); +#endif + + return product; +} + /* ************************************************************************ */ static std::pair> discreteElimination(const HybridGaussianFactorGraph &factors, @@ -299,8 +341,32 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscrete); #endif - // NOTE: This does sum-product. For max-product, use EliminateForMPE. - auto result = EliminateDiscrete(dfg, frontalKeys); + /**** NOTE: This does sum-product. ****/ + // Get product factor + TableFactor product = ProductAndNormalize(factors); + +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteSum); +#endif + // All the discrete variables should form a single clique, + // so we can sum out on all the variables as frontals. + // This should give an empty separator. + Ordering orderedKeys(product.keys()); + DecisionTreeFactor::shared_ptr sum = product.sum(orderedKeys); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscreteSum); +#endif + +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteToDiscreteConditional); +#endif + // Finally, get the conditional + auto conditional = + std::make_shared(product, *sum, orderedKeys); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscreteToDiscreteConditional); +#endif + #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscrete); #endif From 47e76ff03b77a007ca166e07b8aa7a2a042e49e3 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 14:27:20 -0500 Subject: [PATCH 4/5] remove GTSAM_HYBRID_TIMING guards --- gtsam/discrete/DiscreteFactorGraph.cpp | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 2037dd951..2b63be660 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -121,25 +121,13 @@ namespace gtsam { static DecisionTreeFactor ProductAndNormalize( const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteProduct); -#endif DecisionTreeFactor product = factors.product(); -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteProduct); -#endif // Max over all the potentials by pretending all keys are frontal: auto normalizer = product.max(product.size()); -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteNormalize); -#endif // Normalize the product factor to prevent underflow. product = product / (*normalizer); -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteNormalize); -#endif return product; } @@ -230,13 +218,7 @@ namespace gtsam { DecisionTreeFactor product = ProductAndNormalize(factors); // sum out frontals, this is the factor on the separator -#if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteSum); -#endif DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteSum); -#endif // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; @@ -246,14 +228,8 @@ namespace gtsam { sum->keys().end()); // now divide product/sum to get conditional -#if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteToDiscreteConditional); -#endif auto conditional = std::make_shared(product, *sum, orderedKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteToDiscreteConditional); -#endif return {conditional, sum}; } From 0820fcb7b273f9cf025408cb12b3b8e04ff9a21b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 14:50:51 -0500 Subject: [PATCH 5/5] fix types --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8b0a2349f..8f63d6e71 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -268,7 +268,7 @@ static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) { gttic_(DiscreteProduct); #endif TableFactor product; - for (const sharedFactor &factor : factors) { + for (auto &&factor : factors) { if (factor) { if (auto f = std::dynamic_pointer_cast(factor)) { product = product * (*f); @@ -343,7 +343,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #endif /**** NOTE: This does sum-product. ****/ // Get product factor - TableFactor product = ProductAndNormalize(factors); + TableFactor product = ProductAndNormalize(dfg); #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteSum); @@ -352,26 +352,26 @@ discreteElimination(const HybridGaussianFactorGraph &factors, // so we can sum out on all the variables as frontals. // This should give an empty separator. Ordering orderedKeys(product.keys()); - DecisionTreeFactor::shared_ptr sum = product.sum(orderedKeys); + TableFactor::shared_ptr sum = product.sum(orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteSum); #endif #if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteToDiscreteConditional); + gttic_(EliminateDiscreteFormDiscreteConditional); #endif // Finally, get the conditional auto conditional = std::make_shared(product, *sum, orderedKeys); #if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteToDiscreteConditional); + gttoc_(EliminateDiscreteFormDiscreteConditional); #endif #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscrete); #endif - return {std::make_shared(result.first), result.second}; + return {std::make_shared(conditional), sum}; } /* ************************************************************************ */