add missing tests for DecisionTreeFactor::restrict

release/4.3a0
Varun Agrawal 2025-04-11 15:04:06 -04:00
parent e7f758e794
commit b7d0931fd9
1 changed files with 26 additions and 0 deletions

View File

@ -172,6 +172,32 @@ TEST(DecisionTreeFactor, enumerate) {
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Test if restricting a factor based on DiscreteValues works.
TEST(DecisionTreeFactor, Restrict) {
// Test for restricting a single value from multiple values.
DiscreteKey A(12, 2), B(5, 3);
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
DiscreteValues fixedValues = {{A.first, 1}};
DecisionTreeFactor restricted_f1 =
*std::static_pointer_cast<DecisionTreeFactor>(f1.restrict(fixedValues));
DecisionTreeFactor expected_f1(B, "4 5 6");
EXPECT(assert_equal(expected_f1, restricted_f1));
// Test for restricting a multiple value from multiple values.
DiscreteKey C(91, 2);
DecisionTreeFactor f2(A & B & C, "1 2 3 4 5 6 7 8 9 10 11 12");
fixedValues = {{A.first, 0}, {B.first, 2}};
DecisionTreeFactor restricted_f2 =
*std::static_pointer_cast<DecisionTreeFactor>(f2.restrict(fixedValues));
DecisionTreeFactor expected_f2(C, "5 6");
EXPECT(assert_equal(expected_f2, restricted_f2));
}
namespace pruning_fixture {
DiscreteKey A(1, 2), B(2, 2), C(3, 2);