diff --git a/gtsam/base/timing.h b/gtsam/base/timing.h index 5a96633db..35a67eba9 100644 --- a/gtsam/base/timing.h +++ b/gtsam/base/timing.h @@ -206,8 +206,8 @@ namespace gtsam { * (CPU time, number of times, wall time, time + children in seconds, min * time, max time) * - * @param addLineBreak Flag indicating if a line break should be added at - * the end. Only used at the top-leve. + * @param addLineBreak Flag indicating if a line break should be + * added at the end. Only used at the top-level. */ GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const; @@ -217,8 +217,8 @@ namespace gtsam { * (CPU time, number of times, wall time, time + children in seconds, min * time, max time) * - * @param addLineBreak Flag indicating if a line break should be added at - * the end. Only used at the top-leve. + * @param addLineBreak Flag indicating if a line break should be + * added at the end. Only used at the top-level. */ GTSAM_EXPORT void printCsv(bool addLineBreak = false) const; diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ebc02c1b5..39d678ad2 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -548,8 +548,20 @@ namespace gtsam { DiscreteFactor::shared_ptr DecisionTreeFactor::restrict( const DiscreteValues& assignment) const { ADT restricted_tree = ADT::restrict(assignment); - return std::make_shared(this->discreteKeys(), - restricted_tree); + // Get all the keys that are not restricted by the assignment + // This ensures that the new restricted factor doesn't have keys + // for which the information has been removed. + DiscreteKeys restricted_keys = this->discreteKeys(); + for (auto&& kv : assignment) { + Key key = kv.first; + // Remove the key from the keys list + restricted_keys.erase( + std::remove_if(restricted_keys.begin(), restricted_keys.end(), + [key](const DiscreteKey& k) { return k.first == key; }), + restricted_keys.end()); + } + // Create the restricted factor with the appropriate keys and tree. + return std::make_shared(restricted_keys, restricted_tree); } /* ************************************************************************ */ diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 94c9624ca..52fa7c097 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -172,6 +172,45 @@ TEST(DecisionTreeFactor, enumerate) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Test if restricting a factor based on DiscreteValues works. +TEST(DecisionTreeFactor, Restrict) { + // Test for restricting a single value from multiple values. + DiscreteKey A(12, 2), B(5, 3); + DecisionTreeFactor f1(A & B, "1 2 3 4 5 6"); + DiscreteValues fixedValues = {{A.first, 1}}; + + DecisionTreeFactor restricted_f1 = + *std::static_pointer_cast(f1.restrict(fixedValues)); + + DecisionTreeFactor expected_f1(B, "4 5 6"); + EXPECT(assert_equal(expected_f1, restricted_f1)); + + // Test for restricting a multiple value from multiple values. + DiscreteKey C(91, 2); + DecisionTreeFactor f2(A & B & C, "1 2 3 4 5 6 7 8 9 10 11 12"); + fixedValues = {{A.first, 0}, {B.first, 2}}; + + DecisionTreeFactor restricted_f2 = + *std::static_pointer_cast(f2.restrict(fixedValues)); + + DecisionTreeFactor expected_f2(C, "5 6"); + EXPECT(assert_equal(expected_f2, restricted_f2)); + + // Edge case of restricting a single value when it is the only value. + DecisionTreeFactor f3(A, "50 100"); + fixedValues = {{A.first, 1}}; // select 100 + + DecisionTreeFactor restricted_f3 = + *std::static_pointer_cast(f3.restrict(fixedValues)); + + EXPECT_LONGS_EQUAL(0, restricted_f3.discreteKeys().size()); + // There should only be 1 value which is 100 + EXPECT_LONGS_EQUAL(1, restricted_f3.nrValues()); + EXPECT_LONGS_EQUAL(1, restricted_f3.nrLeaves()); + EXPECT_DOUBLES_EQUAL(100, restricted_f3.evaluate(DiscreteValues()), 1e-9); +} + namespace pruning_fixture { DiscreteKey A(1, 2), B(2, 2), C(3, 2); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 0840cb381..08bab93ec 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -223,7 +223,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @note If marginal greater than this threshold, the mode gets assigned that * value and is considered "dead" for hybrid elimination. The mode can then be * removed since it only has a single possible assignment. - + * * @return A pruned HybridBayesNet */ HybridBayesNet prune(size_t maxNrLeaves, diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index b42676aac..11dd6a2a5 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -23,7 +23,6 @@ #include #include #include - namespace gtsam { /* ************************************************************************* */ @@ -237,12 +236,20 @@ HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict( if (auto hf = dynamic_pointer_cast(f)) { result.push_back(hf->restrict(discreteValues)); } else if (auto df = dynamic_pointer_cast(f)) { - result.push_back(df->restrict(discreteValues)); + auto restricted_df = df->restrict(discreteValues); + // In the case where all the discrete values in the factor + // have been selected, we get a factor without any keys, + // and default values of 0.5. + // Since this factor no longer adds any information, we ignore it to make + // inference faster. + if (restricted_df->discreteKeys().size() > 0) { + result.push_back(restricted_df); + } } else { result.push_back(f); // Everything else is just added as is } } - + return result; }