test for computeThreshold
parent
8473911926
commit
1091b9cd2d
|
@ -140,11 +140,46 @@ TEST(DecisionTreeFactor, enumerate) {
|
|||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
namespace pruning_fixture {
|
||||
|
||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||
DecisionTreeFactor f(A& B& C, "1 5 3 7 2 6 4 8");
|
||||
|
||||
DiscreteKey D(4, 2);
|
||||
DecisionTreeFactor factor(
|
||||
D& C & B & A,
|
||||
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
|
||||
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
|
||||
|
||||
} // namespace pruning_fixture
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check if computing the correct threshold works.
|
||||
TEST(DecisionTreeFactor, ComputeThreshold) {
|
||||
using namespace pruning_fixture;
|
||||
|
||||
// Only keep the leaves with the top 5 values.
|
||||
double threshold = f.computeThreshold(5);
|
||||
EXPECT_DOUBLES_EQUAL(4.0, threshold, 1e-9);
|
||||
|
||||
// Check for more extreme pruning where we only keep the top 2 leaves
|
||||
threshold = f.computeThreshold(2);
|
||||
EXPECT_DOUBLES_EQUAL(7.0, threshold, 1e-9);
|
||||
|
||||
threshold = factor.computeThreshold(5);
|
||||
EXPECT_DOUBLES_EQUAL(0.99995287, threshold, 1e-9);
|
||||
|
||||
threshold = factor.computeThreshold(3);
|
||||
EXPECT_DOUBLES_EQUAL(1.0, threshold, 1e-9);
|
||||
|
||||
threshold = factor.computeThreshold(6);
|
||||
EXPECT_DOUBLES_EQUAL(0.61247742, threshold, 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check pruning of the decision tree works as expected.
|
||||
TEST(DecisionTreeFactor, Prune) {
|
||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
|
||||
using namespace pruning_fixture;
|
||||
|
||||
// Only keep the leaves with the top 5 values.
|
||||
size_t maxNrAssignments = 5;
|
||||
|
@ -160,12 +195,6 @@ TEST(DecisionTreeFactor, Prune) {
|
|||
DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
|
||||
EXPECT(assert_equal(expected2, pruned2));
|
||||
|
||||
DiscreteKey D(4, 2);
|
||||
DecisionTreeFactor factor(
|
||||
D & C & B & A,
|
||||
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
|
||||
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
|
||||
|
||||
DecisionTreeFactor expected3(D & C & B & A,
|
||||
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
|
||||
"0.999952870000 1.0 1.0 1.0 1.0");
|
||||
|
|
Loading…
Reference in New Issue