add missing tests for DecisionTreeFactor::restrict
parent
e7f758e794
commit
b7d0931fd9
|
@ -172,6 +172,32 @@ TEST(DecisionTreeFactor, enumerate) {
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Test if restricting a factor based on DiscreteValues works.
|
||||||
|
TEST(DecisionTreeFactor, Restrict) {
|
||||||
|
// Test for restricting a single value from multiple values.
|
||||||
|
DiscreteKey A(12, 2), B(5, 3);
|
||||||
|
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
|
||||||
|
DiscreteValues fixedValues = {{A.first, 1}};
|
||||||
|
|
||||||
|
DecisionTreeFactor restricted_f1 =
|
||||||
|
*std::static_pointer_cast<DecisionTreeFactor>(f1.restrict(fixedValues));
|
||||||
|
|
||||||
|
DecisionTreeFactor expected_f1(B, "4 5 6");
|
||||||
|
EXPECT(assert_equal(expected_f1, restricted_f1));
|
||||||
|
|
||||||
|
// Test for restricting a multiple value from multiple values.
|
||||||
|
DiscreteKey C(91, 2);
|
||||||
|
DecisionTreeFactor f2(A & B & C, "1 2 3 4 5 6 7 8 9 10 11 12");
|
||||||
|
fixedValues = {{A.first, 0}, {B.first, 2}};
|
||||||
|
|
||||||
|
DecisionTreeFactor restricted_f2 =
|
||||||
|
*std::static_pointer_cast<DecisionTreeFactor>(f2.restrict(fixedValues));
|
||||||
|
|
||||||
|
DecisionTreeFactor expected_f2(C, "5 6");
|
||||||
|
EXPECT(assert_equal(expected_f2, restricted_f2));
|
||||||
|
}
|
||||||
|
|
||||||
namespace pruning_fixture {
|
namespace pruning_fixture {
|
||||||
|
|
||||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||||
|
|
Loading…
Reference in New Issue