From 094b76df2d4c43545353c7ec9f70d6dd4eee0f27 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 17:06:57 -0500 Subject: [PATCH 1/8] fix bug in TableFactor when trying to convert to DecisionTreeFactor --- gtsam/discrete/TableFactor.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index bf9662e34..67ba19d39 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -252,6 +252,11 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); + // If no keys, then return empty DecisionTreeFactor + if (dkeys.size() == 0) { + return DecisionTreeFactor(dkeys, AlgebraicDecisionTree()); + } + std::vector table; for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); From bf4c0bd72de2fabc760562b5de908cad5ec751fa Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 17:07:25 -0500 Subject: [PATCH 2/8] fix creation of DiscreteConditional --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 8 ++++---- gtsam/hybrid/tests/testGaussianMixture.cpp | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8f63d6e71..831e0ccc2 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -351,8 +351,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, // All the discrete variables should form a single clique, // so we can sum out on all the variables as frontals. // This should give an empty separator. - Ordering orderedKeys(product.keys()); - TableFactor::shared_ptr sum = product.sum(orderedKeys); + TableFactor::shared_ptr sum = product.sum(frontalKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteSum); #endif @@ -361,8 +360,9 @@ discreteElimination(const HybridGaussianFactorGraph &factors, gttic_(EliminateDiscreteFormDiscreteConditional); #endif // Finally, get the conditional - auto conditional = - std::make_shared(product, *sum, orderedKeys); + auto c = product / (*sum); + auto conditional = std::make_shared( + frontalKeys.size(), c.toDecisionTreeFactor()); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 14bef5fbb..698c1bbf6 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) { EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); } } + /* ************************************************************************* */ int main() { TestResult tr; From 0e2e8bb8ced3be38e150e72df03869f7ca49632f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 18:23:37 -0500 Subject: [PATCH 3/8] full discrete elimination --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 831e0ccc2..ca971191c 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -356,13 +356,17 @@ discreteElimination(const HybridGaussianFactorGraph &factors, gttoc_(EliminateDiscreteSum); #endif + // Ordering keys for the conditional so that frontalKeys are really in front + Ordering orderedKeys; + orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); + orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); + #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); #endif // Finally, get the conditional - auto c = product / (*sum); auto conditional = std::make_shared( - frontalKeys.size(), c.toDecisionTreeFactor()); + product.toDecisionTreeFactor(), sum->toDecisionTreeFactor(), orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif From 73f54083a7aeab7c996f865e4fe77d3ab3cdfb1d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 19:34:46 -0500 Subject: [PATCH 4/8] normalize --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index ca971191c..5fd9ddfa6 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -364,9 +364,19 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); #endif + DecisionTreeFactor joint; + // Normalize if we have only 1 key + // Needed due to conversion from TableFactor + if (product.discreteKeys().size() == 1) { + joint = DecisionTreeFactor(product.discreteKeys(), + product.toDecisionTreeFactor().normalize()); + } else { + joint = product.toDecisionTreeFactor(); + } + // Finally, get the conditional auto conditional = std::make_shared( - product.toDecisionTreeFactor(), sum->toDecisionTreeFactor(), orderedKeys); + joint, sum->toDecisionTreeFactor(), orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif From a71008d7fd572b679c7816f19f707d2d2770c3b8 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 20:35:36 -0500 Subject: [PATCH 5/8] new helper constructor for DiscreteConditional --- gtsam/discrete/DiscreteConditional.cpp | 11 ++++++++++- gtsam/discrete/DiscreteConditional.h | 11 +++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index eeb5dca3f..8396b10e0 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -24,13 +24,13 @@ #include #include +#include #include #include #include #include #include #include -#include using namespace std; using std::pair; @@ -47,6 +47,15 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, const DecisionTreeFactor& f) : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(size_t nrFrontals, + const DecisionTreeFactor& f, + const Ordering& orderedKeys) + : BaseFactor(f), BaseConditional(nrFrontals) { + keys_.clear(); + keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); +} + /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(size_t nrFrontals, const DiscreteKeys& keys, diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 3ec9ae590..549504985 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -56,6 +56,17 @@ class GTSAM_EXPORT DiscreteConditional /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + /** + * @brief Construct from DecisionTreeFactor, + * taking the first `nrFrontals` from `orderedKeys`. + * + * @param nrFrontals The number of frontal variables. + * @param f The DecisionTreeFactor to construct from. + * @param orderedKeys Ordered list of keys involved in the conditional. + */ + DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f, + const Ordering& orderedKeys); + /** * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first * `nFrontals` keys as frontals, in the order given. From 57c426a870023400c5a9f8d7bc43f508a6b4c082 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 20:36:07 -0500 Subject: [PATCH 6/8] simplify discrete conditional computation --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 5fd9ddfa6..e48052c19 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -364,19 +364,9 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); #endif - DecisionTreeFactor joint; - // Normalize if we have only 1 key - // Needed due to conversion from TableFactor - if (product.discreteKeys().size() == 1) { - joint = DecisionTreeFactor(product.discreteKeys(), - product.toDecisionTreeFactor().normalize()); - } else { - joint = product.toDecisionTreeFactor(); - } - - // Finally, get the conditional + auto c = product / (*sum); auto conditional = std::make_shared( - joint, sum->toDecisionTreeFactor(), orderedKeys); + c.toDecisionTreeFactor(), frontalKeys.size(), orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif From ffa40f71012d484da9e55323f6bb7e21eb801934 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 20:37:04 -0500 Subject: [PATCH 7/8] small fix --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index e48052c19..387a5849f 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -366,7 +366,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #endif auto c = product / (*sum); auto conditional = std::make_shared( - c.toDecisionTreeFactor(), frontalKeys.size(), orderedKeys); + frontalKeys.size(), c.toDecisionTreeFactor(), orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif From ab47adeb1873ba72d9ad2e7bb7cd6a6c94747232 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 20:45:11 -0500 Subject: [PATCH 8/8] fix empty keys case --- gtsam/discrete/TableFactor.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 67ba19d39..a833e1c5e 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -254,7 +254,11 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { // If no keys, then return empty DecisionTreeFactor if (dkeys.size() == 0) { - return DecisionTreeFactor(dkeys, AlgebraicDecisionTree()); + AlgebraicDecisionTree tree; + if (sparse_table_.size() != 0) { + tree = AlgebraicDecisionTree(sparse_table_.coeff(0)); + } + return DecisionTreeFactor(dkeys, tree); } std::vector table;