From 42f8e54c2a31127251401e410ceea932c38127a2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 14:27:01 -0500 Subject: [PATCH 01/18] customize discrete elimination in Hybrid --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 70 +++++++++++++++++++++- 1 file changed, 68 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 4fcd420b1..8b0a2349f 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -255,6 +255,48 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors( return std::make_shared(discreteKeys, potentials); } +/** + * @brief Multiply all the `factors` and normalize the + * product to prevent underflow. + * + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return TableFactor + */ +static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) { + // PRODUCT: multiply all factors +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteProduct); +#endif + TableFactor product; + for (const sharedFactor &factor : factors) { + if (factor) { + if (auto f = std::dynamic_pointer_cast(factor)) { + product = product * (*f); + } else if (auto dtf = + std::dynamic_pointer_cast(factor)) { + product = TableFactor(product * (*dtf)); + } + } + } +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteProduct); +#endif + + // Max over all the potentials by pretending all keys are frontal: + auto normalizer = product.max(product.size()); + +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteNormalize); +#endif + // Normalize the product factor to prevent underflow. + product = product / (*normalizer); +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteNormalize); +#endif + + return product; +} + /* ************************************************************************ */ static std::pair> discreteElimination(const HybridGaussianFactorGraph &factors, @@ -299,8 +341,32 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscrete); #endif - // NOTE: This does sum-product. For max-product, use EliminateForMPE. - auto result = EliminateDiscrete(dfg, frontalKeys); + /**** NOTE: This does sum-product. ****/ + // Get product factor + TableFactor product = ProductAndNormalize(factors); + +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteSum); +#endif + // All the discrete variables should form a single clique, + // so we can sum out on all the variables as frontals. + // This should give an empty separator. + Ordering orderedKeys(product.keys()); + DecisionTreeFactor::shared_ptr sum = product.sum(orderedKeys); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscreteSum); +#endif + +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteToDiscreteConditional); +#endif + // Finally, get the conditional + auto conditional = + std::make_shared(product, *sum, orderedKeys); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscreteToDiscreteConditional); +#endif + #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscrete); #endif From 47e76ff03b77a007ca166e07b8aa7a2a042e49e3 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 14:27:20 -0500 Subject: [PATCH 02/18] remove GTSAM_HYBRID_TIMING guards --- gtsam/discrete/DiscreteFactorGraph.cpp | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 2037dd951..2b63be660 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -121,25 +121,13 @@ namespace gtsam { static DecisionTreeFactor ProductAndNormalize( const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteProduct); -#endif DecisionTreeFactor product = factors.product(); -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteProduct); -#endif // Max over all the potentials by pretending all keys are frontal: auto normalizer = product.max(product.size()); -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteNormalize); -#endif // Normalize the product factor to prevent underflow. product = product / (*normalizer); -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteNormalize); -#endif return product; } @@ -230,13 +218,7 @@ namespace gtsam { DecisionTreeFactor product = ProductAndNormalize(factors); // sum out frontals, this is the factor on the separator -#if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteSum); -#endif DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteSum); -#endif // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; @@ -246,14 +228,8 @@ namespace gtsam { sum->keys().end()); // now divide product/sum to get conditional -#if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteToDiscreteConditional); -#endif auto conditional = std::make_shared(product, *sum, orderedKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteToDiscreteConditional); -#endif return {conditional, sum}; } From 0820fcb7b273f9cf025408cb12b3b8e04ff9a21b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 14:50:51 -0500 Subject: [PATCH 03/18] fix types --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8b0a2349f..8f63d6e71 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -268,7 +268,7 @@ static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) { gttic_(DiscreteProduct); #endif TableFactor product; - for (const sharedFactor &factor : factors) { + for (auto &&factor : factors) { if (factor) { if (auto f = std::dynamic_pointer_cast(factor)) { product = product * (*f); @@ -343,7 +343,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #endif /**** NOTE: This does sum-product. ****/ // Get product factor - TableFactor product = ProductAndNormalize(factors); + TableFactor product = ProductAndNormalize(dfg); #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteSum); @@ -352,26 +352,26 @@ discreteElimination(const HybridGaussianFactorGraph &factors, // so we can sum out on all the variables as frontals. // This should give an empty separator. Ordering orderedKeys(product.keys()); - DecisionTreeFactor::shared_ptr sum = product.sum(orderedKeys); + TableFactor::shared_ptr sum = product.sum(orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteSum); #endif #if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteToDiscreteConditional); + gttic_(EliminateDiscreteFormDiscreteConditional); #endif // Finally, get the conditional auto conditional = std::make_shared(product, *sum, orderedKeys); #if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteToDiscreteConditional); + gttoc_(EliminateDiscreteFormDiscreteConditional); #endif #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscrete); #endif - return {std::make_shared(result.first), result.second}; + return {std::make_shared(conditional), sum}; } /* ************************************************************************ */ From 094b76df2d4c43545353c7ec9f70d6dd4eee0f27 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 17:06:57 -0500 Subject: [PATCH 04/18] fix bug in TableFactor when trying to convert to DecisionTreeFactor --- gtsam/discrete/TableFactor.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index bf9662e34..67ba19d39 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -252,6 +252,11 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); + // If no keys, then return empty DecisionTreeFactor + if (dkeys.size() == 0) { + return DecisionTreeFactor(dkeys, AlgebraicDecisionTree()); + } + std::vector table; for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); From bf4c0bd72de2fabc760562b5de908cad5ec751fa Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 17:07:25 -0500 Subject: [PATCH 05/18] fix creation of DiscreteConditional --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 8 ++++---- gtsam/hybrid/tests/testGaussianMixture.cpp | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8f63d6e71..831e0ccc2 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -351,8 +351,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, // All the discrete variables should form a single clique, // so we can sum out on all the variables as frontals. // This should give an empty separator. - Ordering orderedKeys(product.keys()); - TableFactor::shared_ptr sum = product.sum(orderedKeys); + TableFactor::shared_ptr sum = product.sum(frontalKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteSum); #endif @@ -361,8 +360,9 @@ discreteElimination(const HybridGaussianFactorGraph &factors, gttic_(EliminateDiscreteFormDiscreteConditional); #endif // Finally, get the conditional - auto conditional = - std::make_shared(product, *sum, orderedKeys); + auto c = product / (*sum); + auto conditional = std::make_shared( + frontalKeys.size(), c.toDecisionTreeFactor()); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 14bef5fbb..698c1bbf6 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) { EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); } } + /* ************************************************************************* */ int main() { TestResult tr; From 0e2e8bb8ced3be38e150e72df03869f7ca49632f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 18:23:37 -0500 Subject: [PATCH 06/18] full discrete elimination --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 831e0ccc2..ca971191c 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -356,13 +356,17 @@ discreteElimination(const HybridGaussianFactorGraph &factors, gttoc_(EliminateDiscreteSum); #endif + // Ordering keys for the conditional so that frontalKeys are really in front + Ordering orderedKeys; + orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); + orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); + #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); #endif // Finally, get the conditional - auto c = product / (*sum); auto conditional = std::make_shared( - frontalKeys.size(), c.toDecisionTreeFactor()); + product.toDecisionTreeFactor(), sum->toDecisionTreeFactor(), orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif From 73f54083a7aeab7c996f865e4fe77d3ab3cdfb1d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 19:34:46 -0500 Subject: [PATCH 07/18] normalize --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index ca971191c..5fd9ddfa6 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -364,9 +364,19 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); #endif + DecisionTreeFactor joint; + // Normalize if we have only 1 key + // Needed due to conversion from TableFactor + if (product.discreteKeys().size() == 1) { + joint = DecisionTreeFactor(product.discreteKeys(), + product.toDecisionTreeFactor().normalize()); + } else { + joint = product.toDecisionTreeFactor(); + } + // Finally, get the conditional auto conditional = std::make_shared( - product.toDecisionTreeFactor(), sum->toDecisionTreeFactor(), orderedKeys); + joint, sum->toDecisionTreeFactor(), orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif From a71008d7fd572b679c7816f19f707d2d2770c3b8 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 20:35:36 -0500 Subject: [PATCH 08/18] new helper constructor for DiscreteConditional --- gtsam/discrete/DiscreteConditional.cpp | 11 ++++++++++- gtsam/discrete/DiscreteConditional.h | 11 +++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index eeb5dca3f..8396b10e0 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -24,13 +24,13 @@ #include #include +#include #include #include #include #include #include #include -#include using namespace std; using std::pair; @@ -47,6 +47,15 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, const DecisionTreeFactor& f) : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(size_t nrFrontals, + const DecisionTreeFactor& f, + const Ordering& orderedKeys) + : BaseFactor(f), BaseConditional(nrFrontals) { + keys_.clear(); + keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); +} + /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(size_t nrFrontals, const DiscreteKeys& keys, diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 3ec9ae590..549504985 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -56,6 +56,17 @@ class GTSAM_EXPORT DiscreteConditional /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + /** + * @brief Construct from DecisionTreeFactor, + * taking the first `nrFrontals` from `orderedKeys`. + * + * @param nrFrontals The number of frontal variables. + * @param f The DecisionTreeFactor to construct from. + * @param orderedKeys Ordered list of keys involved in the conditional. + */ + DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f, + const Ordering& orderedKeys); + /** * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first * `nFrontals` keys as frontals, in the order given. From 57c426a870023400c5a9f8d7bc43f508a6b4c082 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 20:36:07 -0500 Subject: [PATCH 09/18] simplify discrete conditional computation --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 5fd9ddfa6..e48052c19 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -364,19 +364,9 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); #endif - DecisionTreeFactor joint; - // Normalize if we have only 1 key - // Needed due to conversion from TableFactor - if (product.discreteKeys().size() == 1) { - joint = DecisionTreeFactor(product.discreteKeys(), - product.toDecisionTreeFactor().normalize()); - } else { - joint = product.toDecisionTreeFactor(); - } - - // Finally, get the conditional + auto c = product / (*sum); auto conditional = std::make_shared( - joint, sum->toDecisionTreeFactor(), orderedKeys); + c.toDecisionTreeFactor(), frontalKeys.size(), orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif From ffa40f71012d484da9e55323f6bb7e21eb801934 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 20:37:04 -0500 Subject: [PATCH 10/18] small fix --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index e48052c19..387a5849f 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -366,7 +366,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #endif auto c = product / (*sum); auto conditional = std::make_shared( - c.toDecisionTreeFactor(), frontalKeys.size(), orderedKeys); + frontalKeys.size(), c.toDecisionTreeFactor(), orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif From ab47adeb1873ba72d9ad2e7bb7cd6a6c94747232 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 20:45:11 -0500 Subject: [PATCH 11/18] fix empty keys case --- gtsam/discrete/TableFactor.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 67ba19d39..a833e1c5e 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -254,7 +254,11 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { // If no keys, then return empty DecisionTreeFactor if (dkeys.size() == 0) { - return DecisionTreeFactor(dkeys, AlgebraicDecisionTree()); + AlgebraicDecisionTree tree; + if (sparse_table_.size() != 0) { + tree = AlgebraicDecisionTree(sparse_table_.coeff(0)); + } + return DecisionTreeFactor(dkeys, tree); } std::vector table; From e49b40b4c45373a25f36052ba6026d2ceeebf77f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 1 Jan 2025 19:03:00 -0500 Subject: [PATCH 12/18] remove TableFactor check for another day --- gtsam/discrete/TableFactor.cpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a833e1c5e..bf9662e34 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -252,15 +252,6 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); - // If no keys, then return empty DecisionTreeFactor - if (dkeys.size() == 0) { - AlgebraicDecisionTree tree; - if (sparse_table_.size() != 0) { - tree = AlgebraicDecisionTree(sparse_table_.coeff(0)); - } - return DecisionTreeFactor(dkeys, tree); - } - std::vector table; for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); From bb4ee207b837aa5a613461b6c20a67e2d01ef29d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 1 Jan 2025 19:13:13 -0500 Subject: [PATCH 13/18] custom path for empty separator --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 49 +++++++++++----------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 387a5849f..7cc890dc0 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -341,41 +341,40 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscrete); #endif - /**** NOTE: This does sum-product. ****/ - // Get product factor - TableFactor product = ProductAndNormalize(dfg); + // Check if separator is empty + Ordering allKeys(dfg.keyVector()); + Ordering separator; + std::set_difference(allKeys.begin(), allKeys.end(), frontalKeys.begin(), + frontalKeys.end(), + std::inserter(separator, separator.begin())); + + // If the separator is empty, we have a clique of all the discrete variables + // so we can use the TableFactor for efficiency. + if (separator.size() == 0) { + // Get product factor + TableFactor product = ProductAndNormalize(dfg); #if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteSum); + gttic_(EliminateDiscreteFormDiscreteConditional); #endif - // All the discrete variables should form a single clique, - // so we can sum out on all the variables as frontals. - // This should give an empty separator. - TableFactor::shared_ptr sum = product.sum(frontalKeys); + auto conditional = std::make_shared( + frontalKeys.size(), product.toDecisionTreeFactor()); #if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteSum); + gttoc_(EliminateDiscreteFormDiscreteConditional); #endif - // Ordering keys for the conditional so that frontalKeys are really in front - Ordering orderedKeys; - orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); - orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); - + TableFactor::shared_ptr sum = product.sum(frontalKeys); #if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteFormDiscreteConditional); -#endif - auto c = product / (*sum); - auto conditional = std::make_shared( - frontalKeys.size(), c.toDecisionTreeFactor(), orderedKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteFormDiscreteConditional); + gttoc_(EliminateDiscrete); #endif -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscrete); -#endif + return {std::make_shared(conditional), sum}; - return {std::make_shared(conditional), sum}; + } else { + // Perform sum-product. + auto result = EliminateDiscrete(dfg, frontalKeys); + return {std::make_shared(result.first), result.second}; + } } /* ************************************************************************ */ From 2894c957b1942d355b4c014b0a176ae6d592d078 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 1 Jan 2025 19:15:49 -0500 Subject: [PATCH 14/18] clarify TableProduct function --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 7cc890dc0..25047bfad 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -256,13 +256,12 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors( } /** - * @brief Multiply all the `factors` and normalize the - * product to prevent underflow. + * @brief Multiply all the `factors` using the machinery of the TableFactor. * * @param factors The factors to multiply as a DiscreteFactorGraph. * @return TableFactor */ -static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) { +static TableFactor TableProduct(const DiscreteFactorGraph &factors) { // PRODUCT: multiply all factors #if GTSAM_HYBRID_TIMING gttic_(DiscreteProduct); @@ -282,14 +281,13 @@ static TableFactor ProductAndNormalize(const DiscreteFactorGraph &factors) { gttoc_(DiscreteProduct); #endif - // Max over all the potentials by pretending all keys are frontal: - auto normalizer = product.max(product.size()); - #if GTSAM_HYBRID_TIMING gttic_(DiscreteNormalize); #endif + // Max over all the potentials by pretending all keys are frontal: + auto denominator = product.max(product.size()); // Normalize the product factor to prevent underflow. - product = product / (*normalizer); + product = product / (*denominator); #if GTSAM_HYBRID_TIMING gttoc_(DiscreteNormalize); #endif @@ -352,7 +350,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, // so we can use the TableFactor for efficiency. if (separator.size() == 0) { // Get product factor - TableFactor product = ProductAndNormalize(dfg); + TableFactor product = TableProduct(dfg); #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); From d22ba290547cb9a81c6e2656bd877bc4cf5a836d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 1 Jan 2025 19:18:05 -0500 Subject: [PATCH 15/18] remove DiscreteConditional constructor since we no longer use it --- gtsam/discrete/DiscreteConditional.cpp | 9 --------- gtsam/discrete/DiscreteConditional.h | 11 ----------- 2 files changed, 20 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 8396b10e0..0eea8b4bd 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -47,15 +47,6 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, const DecisionTreeFactor& f) : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} -/* ************************************************************************** */ -DiscreteConditional::DiscreteConditional(size_t nrFrontals, - const DecisionTreeFactor& f, - const Ordering& orderedKeys) - : BaseFactor(f), BaseConditional(nrFrontals) { - keys_.clear(); - keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); -} - /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(size_t nrFrontals, const DiscreteKeys& keys, diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 549504985..3ec9ae590 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -56,17 +56,6 @@ class GTSAM_EXPORT DiscreteConditional /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); - /** - * @brief Construct from DecisionTreeFactor, - * taking the first `nrFrontals` from `orderedKeys`. - * - * @param nrFrontals The number of frontal variables. - * @param f The DecisionTreeFactor to construct from. - * @param orderedKeys Ordered list of keys involved in the conditional. - */ - DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f, - const Ordering& orderedKeys); - /** * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first * `nFrontals` keys as frontals, in the order given. From 02d99590334f44887f664ddf9f16571a6170f149 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 2 Jan 2025 16:04:20 -0500 Subject: [PATCH 16/18] small fix --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 25047bfad..34ee3de8c 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -273,7 +273,7 @@ static TableFactor TableProduct(const DiscreteFactorGraph &factors) { product = product * (*f); } else if (auto dtf = std::dynamic_pointer_cast(factor)) { - product = TableFactor(product * (*dtf)); + product = product * TableFactor(*dtf); } } } From 113492f8b5e5027de95ae064d47cb279cf85a84d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 2 Jan 2025 16:07:46 -0500 Subject: [PATCH 17/18] separate function to collect discrete factors --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 34ee3de8c..7762249f1 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -296,14 +296,14 @@ static TableFactor TableProduct(const DiscreteFactorGraph &factors) { } /* ************************************************************************ */ -static std::pair> -discreteElimination(const HybridGaussianFactorGraph &factors, - const Ordering &frontalKeys) { +static DiscreteFactorGraph CollectDiscreteFactors( + const HybridGaussianFactorGraph &factors) { DiscreteFactorGraph dfg; for (auto &f : factors) { if (auto df = dynamic_pointer_cast(f)) { dfg.push_back(df); + } else if (auto gmf = dynamic_pointer_cast(f)) { // Case where we have a HybridGaussianFactor with no continuous keys. // In this case, compute a discrete factor from the remaining error. @@ -336,6 +336,15 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } } + return dfg; +} + +/* ************************************************************************ */ +static std::pair> +discreteElimination(const HybridGaussianFactorGraph &factors, + const Ordering &frontalKeys) { + DiscreteFactorGraph dfg = CollectDiscreteFactors(factors); + #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscrete); #endif From 32317d4416bde28ee45e4b986a3bc57cf6b033a9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 2 Jan 2025 16:13:24 -0500 Subject: [PATCH 18/18] simplify empty separator check --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 7762249f1..8be5a8af4 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -348,16 +348,12 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscrete); #endif - // Check if separator is empty - Ordering allKeys(dfg.keyVector()); - Ordering separator; - std::set_difference(allKeys.begin(), allKeys.end(), frontalKeys.begin(), - frontalKeys.end(), - std::inserter(separator, separator.begin())); - + // Check if separator is empty. + // This is the same as checking if the number of frontal variables + // is the same as the number of variables in the DiscreteFactorGraph. // If the separator is empty, we have a clique of all the discrete variables // so we can use the TableFactor for efficiency. - if (separator.size() == 0) { + if (frontalKeys.size() == dfg.keys().size()) { // Get product factor TableFactor product = TableProduct(dfg);