add missing tests for DecisionTreeFactor::restrict
parent
e7f758e794
commit
b7d0931fd9
|
@ -172,6 +172,32 @@ TEST(DecisionTreeFactor, enumerate) {
|
|||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test if restricting a factor based on DiscreteValues works.
|
||||
TEST(DecisionTreeFactor, Restrict) {
|
||||
// Test for restricting a single value from multiple values.
|
||||
DiscreteKey A(12, 2), B(5, 3);
|
||||
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
|
||||
DiscreteValues fixedValues = {{A.first, 1}};
|
||||
|
||||
DecisionTreeFactor restricted_f1 =
|
||||
*std::static_pointer_cast<DecisionTreeFactor>(f1.restrict(fixedValues));
|
||||
|
||||
DecisionTreeFactor expected_f1(B, "4 5 6");
|
||||
EXPECT(assert_equal(expected_f1, restricted_f1));
|
||||
|
||||
// Test for restricting a multiple value from multiple values.
|
||||
DiscreteKey C(91, 2);
|
||||
DecisionTreeFactor f2(A & B & C, "1 2 3 4 5 6 7 8 9 10 11 12");
|
||||
fixedValues = {{A.first, 0}, {B.first, 2}};
|
||||
|
||||
DecisionTreeFactor restricted_f2 =
|
||||
*std::static_pointer_cast<DecisionTreeFactor>(f2.restrict(fixedValues));
|
||||
|
||||
DecisionTreeFactor expected_f2(C, "5 6");
|
||||
EXPECT(assert_equal(expected_f2, restricted_f2));
|
||||
}
|
||||
|
||||
namespace pruning_fixture {
|
||||
|
||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||
|
|
Loading…
Reference in New Issue