test for computeThreshold

release/4.3a0
Varun Agrawal 2024-12-02 09:53:01 -05:00
parent 8473911926
commit 1091b9cd2d
1 changed files with 37 additions and 8 deletions

View File

@ -140,11 +140,46 @@ TEST(DecisionTreeFactor, enumerate) {
EXPECT(actual == expected);
}
namespace pruning_fixture {
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
DecisionTreeFactor f(A& B& C, "1 5 3 7 2 6 4 8");
DiscreteKey D(4, 2);
DecisionTreeFactor factor(
D& C & B & A,
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
} // namespace pruning_fixture
/* ************************************************************************* */
// Check if computing the correct threshold works.
TEST(DecisionTreeFactor, ComputeThreshold) {
using namespace pruning_fixture;
// Only keep the leaves with the top 5 values.
double threshold = f.computeThreshold(5);
EXPECT_DOUBLES_EQUAL(4.0, threshold, 1e-9);
// Check for more extreme pruning where we only keep the top 2 leaves
threshold = f.computeThreshold(2);
EXPECT_DOUBLES_EQUAL(7.0, threshold, 1e-9);
threshold = factor.computeThreshold(5);
EXPECT_DOUBLES_EQUAL(0.99995287, threshold, 1e-9);
threshold = factor.computeThreshold(3);
EXPECT_DOUBLES_EQUAL(1.0, threshold, 1e-9);
threshold = factor.computeThreshold(6);
EXPECT_DOUBLES_EQUAL(0.61247742, threshold, 1e-9);
}
/* ************************************************************************* */
// Check pruning of the decision tree works as expected.
TEST(DecisionTreeFactor, Prune) {
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
using namespace pruning_fixture;
// Only keep the leaves with the top 5 values.
size_t maxNrAssignments = 5;
@ -160,12 +195,6 @@ TEST(DecisionTreeFactor, Prune) {
DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
EXPECT(assert_equal(expected2, pruned2));
DiscreteKey D(4, 2);
DecisionTreeFactor factor(
D & C & B & A,
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
DecisionTreeFactor expected3(D & C & B & A,
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
"0.999952870000 1.0 1.0 1.0 1.0");