diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 94c9624ca..b365d56ce 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -172,6 +172,32 @@ TEST(DecisionTreeFactor, enumerate) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Test if restricting a factor based on DiscreteValues works. +TEST(DecisionTreeFactor, Restrict) { + // Test for restricting a single value from multiple values. + DiscreteKey A(12, 2), B(5, 3); + DecisionTreeFactor f1(A & B, "1 2 3 4 5 6"); + DiscreteValues fixedValues = {{A.first, 1}}; + + DecisionTreeFactor restricted_f1 = + *std::static_pointer_cast(f1.restrict(fixedValues)); + + DecisionTreeFactor expected_f1(B, "4 5 6"); + EXPECT(assert_equal(expected_f1, restricted_f1)); + + // Test for restricting a multiple value from multiple values. + DiscreteKey C(91, 2); + DecisionTreeFactor f2(A & B & C, "1 2 3 4 5 6 7 8 9 10 11 12"); + fixedValues = {{A.first, 0}, {B.first, 2}}; + + DecisionTreeFactor restricted_f2 = + *std::static_pointer_cast(f2.restrict(fixedValues)); + + DecisionTreeFactor expected_f2(C, "5 6"); + EXPECT(assert_equal(expected_f2, restricted_f2)); +} + namespace pruning_fixture { DiscreteKey A(1, 2), B(2, 2), C(3, 2);