diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index a41d06c2b..756a0cebe 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -140,11 +140,46 @@ TEST(DecisionTreeFactor, enumerate) { EXPECT(actual == expected); } +namespace pruning_fixture { + +DiscreteKey A(1, 2), B(2, 2), C(3, 2); +DecisionTreeFactor f(A& B& C, "1 5 3 7 2 6 4 8"); + +DiscreteKey D(4, 2); +DecisionTreeFactor factor( + D& C & B & A, + "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " + "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); + +} // namespace pruning_fixture + +/* ************************************************************************* */ +// Check if computing the correct threshold works. +TEST(DecisionTreeFactor, ComputeThreshold) { + using namespace pruning_fixture; + + // Only keep the leaves with the top 5 values. + double threshold = f.computeThreshold(5); + EXPECT_DOUBLES_EQUAL(4.0, threshold, 1e-9); + + // Check for more extreme pruning where we only keep the top 2 leaves + threshold = f.computeThreshold(2); + EXPECT_DOUBLES_EQUAL(7.0, threshold, 1e-9); + + threshold = factor.computeThreshold(5); + EXPECT_DOUBLES_EQUAL(0.99995287, threshold, 1e-9); + + threshold = factor.computeThreshold(3); + EXPECT_DOUBLES_EQUAL(1.0, threshold, 1e-9); + + threshold = factor.computeThreshold(6); + EXPECT_DOUBLES_EQUAL(0.61247742, threshold, 1e-9); +} + /* ************************************************************************* */ // Check pruning of the decision tree works as expected. TEST(DecisionTreeFactor, Prune) { - DiscreteKey A(1, 2), B(2, 2), C(3, 2); - DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8"); + using namespace pruning_fixture; // Only keep the leaves with the top 5 values. size_t maxNrAssignments = 5; @@ -160,12 +195,6 @@ TEST(DecisionTreeFactor, Prune) { DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); EXPECT(assert_equal(expected2, pruned2)); - DiscreteKey D(4, 2); - DecisionTreeFactor factor( - D & C & B & A, - "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " - "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); - DecisionTreeFactor expected3(D & C & B & A, "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " "0.999952870000 1.0 1.0 1.0 1.0");