From b7d0931fd9efde78c3795b365ddfdcf82e425d12 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 11 Apr 2025 15:04:06 -0400 Subject: [PATCH 1/8] add missing tests for DecisionTreeFactor::restrict --- .../discrete/tests/testDecisionTreeFactor.cpp | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 94c9624ca..b365d56ce 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -172,6 +172,32 @@ TEST(DecisionTreeFactor, enumerate) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Test if restricting a factor based on DiscreteValues works. +TEST(DecisionTreeFactor, Restrict) { + // Test for restricting a single value from multiple values. + DiscreteKey A(12, 2), B(5, 3); + DecisionTreeFactor f1(A & B, "1 2 3 4 5 6"); + DiscreteValues fixedValues = {{A.first, 1}}; + + DecisionTreeFactor restricted_f1 = + *std::static_pointer_cast(f1.restrict(fixedValues)); + + DecisionTreeFactor expected_f1(B, "4 5 6"); + EXPECT(assert_equal(expected_f1, restricted_f1)); + + // Test for restricting a multiple value from multiple values. + DiscreteKey C(91, 2); + DecisionTreeFactor f2(A & B & C, "1 2 3 4 5 6 7 8 9 10 11 12"); + fixedValues = {{A.first, 0}, {B.first, 2}}; + + DecisionTreeFactor restricted_f2 = + *std::static_pointer_cast(f2.restrict(fixedValues)); + + DecisionTreeFactor expected_f2(C, "5 6"); + EXPECT(assert_equal(expected_f2, restricted_f2)); +} + namespace pruning_fixture { DiscreteKey A(1, 2), B(2, 2), C(3, 2); From 1b82faeb0009e65b19f61e48384a2ac22421b6ed Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 11 Apr 2025 15:04:26 -0400 Subject: [PATCH 2/8] update restrict to remove restricted keys --- gtsam/discrete/DecisionTreeFactor.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ebc02c1b5..1ba52642c 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -548,8 +548,16 @@ namespace gtsam { DiscreteFactor::shared_ptr DecisionTreeFactor::restrict( const DiscreteValues& assignment) const { ADT restricted_tree = ADT::restrict(assignment); - return std::make_shared(this->discreteKeys(), - restricted_tree); + DiscreteKeys restricted_keys = this->discreteKeys(); + for (auto&& kv : assignment) { + Key key = kv.first; + // Remove the key from the keys list + restricted_keys.erase( + std::remove_if(restricted_keys.begin(), restricted_keys.end(), + [key](const DiscreteKey& k) { return k.first == key; }), + restricted_keys.end()); + } + return std::make_shared(restricted_keys, restricted_tree); } /* ************************************************************************ */ From d2d1e2c7b0a7a201bbce7a1a704f0ebb43c63ee1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 11 Apr 2025 15:13:45 -0400 Subject: [PATCH 3/8] test case for edge case of restrict --- gtsam/discrete/tests/testDecisionTreeFactor.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index b365d56ce..52fa7c097 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -196,6 +196,19 @@ TEST(DecisionTreeFactor, Restrict) { DecisionTreeFactor expected_f2(C, "5 6"); EXPECT(assert_equal(expected_f2, restricted_f2)); + + // Edge case of restricting a single value when it is the only value. + DecisionTreeFactor f3(A, "50 100"); + fixedValues = {{A.first, 1}}; // select 100 + + DecisionTreeFactor restricted_f3 = + *std::static_pointer_cast(f3.restrict(fixedValues)); + + EXPECT_LONGS_EQUAL(0, restricted_f3.discreteKeys().size()); + // There should only be 1 value which is 100 + EXPECT_LONGS_EQUAL(1, restricted_f3.nrValues()); + EXPECT_LONGS_EQUAL(1, restricted_f3.nrLeaves()); + EXPECT_DOUBLES_EQUAL(100, restricted_f3.evaluate(DiscreteValues()), 1e-9); } namespace pruning_fixture { From 4a7cf4cc40eaea503904e0b0bba1725217b7ca5c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 11 Apr 2025 15:19:24 -0400 Subject: [PATCH 4/8] update HybridNonlinearFactorGraph::restrict to ignore empty discrete factors --- gtsam/hybrid/HybridNonlinearFactorGraph.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index b42676aac..4f8cf45ee 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -237,12 +237,16 @@ HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict( if (auto hf = dynamic_pointer_cast(f)) { result.push_back(hf->restrict(discreteValues)); } else if (auto df = dynamic_pointer_cast(f)) { - result.push_back(df->restrict(discreteValues)); + // In the case where all the values have been selected, we ignore the + // factor since it doesn't add any information + if (df->discreteKeys().size() > 0) { + result.push_back(df->restrict(discreteValues)); + } } else { result.push_back(f); // Everything else is just added as is } } - + return result; } From d4920687047dbf8b9fdbc353d993ca9fc85faf4e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 11 Apr 2025 16:46:11 -0400 Subject: [PATCH 5/8] first restrict before checking numer of keys --- gtsam/hybrid/HybridNonlinearFactorGraph.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 4f8cf45ee..6ab28ccb4 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -23,7 +23,6 @@ #include #include #include - namespace gtsam { /* ************************************************************************* */ @@ -237,10 +236,11 @@ HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict( if (auto hf = dynamic_pointer_cast(f)) { result.push_back(hf->restrict(discreteValues)); } else if (auto df = dynamic_pointer_cast(f)) { - // In the case where all the values have been selected, we ignore the - // factor since it doesn't add any information - if (df->discreteKeys().size() > 0) { - result.push_back(df->restrict(discreteValues)); + auto restricted_df = df->restrict(discreteValues); + // In the case where all the discrete values have been selected, + // we ignore the factor since it doesn't add any information + if (restricted_df->discreteKeys().size() > 0) { + result.push_back(restricted_df); } } else { result.push_back(f); // Everything else is just added as is From 50e93b165b1a000ff4cb360f61a06b738f44a589 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 13 Apr 2025 10:44:23 -0400 Subject: [PATCH 6/8] fix docstring --- gtsam/base/timing.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/base/timing.h b/gtsam/base/timing.h index 5a96633db..35a67eba9 100644 --- a/gtsam/base/timing.h +++ b/gtsam/base/timing.h @@ -206,8 +206,8 @@ namespace gtsam { * (CPU time, number of times, wall time, time + children in seconds, min * time, max time) * - * @param addLineBreak Flag indicating if a line break should be added at - * the end. Only used at the top-leve. + * @param addLineBreak Flag indicating if a line break should be + * added at the end. Only used at the top-level. */ GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const; @@ -217,8 +217,8 @@ namespace gtsam { * (CPU time, number of times, wall time, time + children in seconds, min * time, max time) * - * @param addLineBreak Flag indicating if a line break should be added at - * the end. Only used at the top-leve. + * @param addLineBreak Flag indicating if a line break should be + * added at the end. Only used at the top-level. */ GTSAM_EXPORT void printCsv(bool addLineBreak = false) const; From e2cb8c7f4465e50f76c68ccc116098d8dcee911f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 13 Apr 2025 10:44:29 -0400 Subject: [PATCH 7/8] formatting --- gtsam/hybrid/HybridBayesNet.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 0840cb381..08bab93ec 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -223,7 +223,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @note If marginal greater than this threshold, the mode gets assigned that * value and is considered "dead" for hybrid elimination. The mode can then be * removed since it only has a single possible assignment. - + * * @return A pruned HybridBayesNet */ HybridBayesNet prune(size_t maxNrLeaves, From e30839bac9230ace0369d95d20f9d9d95b1fcc23 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 13 Apr 2025 13:00:05 -0400 Subject: [PATCH 8/8] update comments --- gtsam/discrete/DecisionTreeFactor.cpp | 4 ++++ gtsam/hybrid/HybridNonlinearFactorGraph.cpp | 7 +++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 1ba52642c..39d678ad2 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -548,6 +548,9 @@ namespace gtsam { DiscreteFactor::shared_ptr DecisionTreeFactor::restrict( const DiscreteValues& assignment) const { ADT restricted_tree = ADT::restrict(assignment); + // Get all the keys that are not restricted by the assignment + // This ensures that the new restricted factor doesn't have keys + // for which the information has been removed. DiscreteKeys restricted_keys = this->discreteKeys(); for (auto&& kv : assignment) { Key key = kv.first; @@ -557,6 +560,7 @@ DiscreteFactor::shared_ptr DecisionTreeFactor::restrict( [key](const DiscreteKey& k) { return k.first == key; }), restricted_keys.end()); } + // Create the restricted factor with the appropriate keys and tree. return std::make_shared(restricted_keys, restricted_tree); } diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 6ab28ccb4..11dd6a2a5 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -237,8 +237,11 @@ HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict( result.push_back(hf->restrict(discreteValues)); } else if (auto df = dynamic_pointer_cast(f)) { auto restricted_df = df->restrict(discreteValues); - // In the case where all the discrete values have been selected, - // we ignore the factor since it doesn't add any information + // In the case where all the discrete values in the factor + // have been selected, we get a factor without any keys, + // and default values of 0.5. + // Since this factor no longer adds any information, we ignore it to make + // inference faster. if (restricted_df->discreteKeys().size() > 0) { result.push_back(restricted_df); }