test case for edge case of restrict

release/4.3a0
Varun Agrawal 2025-04-11 15:13:45 -04:00
parent 1b82faeb00
commit d2d1e2c7b0
1 changed files with 13 additions and 0 deletions

View File

@ -196,6 +196,19 @@ TEST(DecisionTreeFactor, Restrict) {
DecisionTreeFactor expected_f2(C, "5 6");
EXPECT(assert_equal(expected_f2, restricted_f2));
// Edge case of restricting a single value when it is the only value.
DecisionTreeFactor f3(A, "50 100");
fixedValues = {{A.first, 1}}; // select 100
DecisionTreeFactor restricted_f3 =
*std::static_pointer_cast<DecisionTreeFactor>(f3.restrict(fixedValues));
EXPECT_LONGS_EQUAL(0, restricted_f3.discreteKeys().size());
// There should only be 1 value which is 100
EXPECT_LONGS_EQUAL(1, restricted_f3.nrValues());
EXPECT_LONGS_EQUAL(1, restricted_f3.nrLeaves());
EXPECT_DOUBLES_EQUAL(100, restricted_f3.evaluate(DiscreteValues()), 1e-9);
}
namespace pruning_fixture {