test for computeThreshold
parent
8473911926
commit
1091b9cd2d
|
@ -140,11 +140,46 @@ TEST(DecisionTreeFactor, enumerate) {
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace pruning_fixture {
|
||||||
|
|
||||||
|
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||||
|
DecisionTreeFactor f(A& B& C, "1 5 3 7 2 6 4 8");
|
||||||
|
|
||||||
|
DiscreteKey D(4, 2);
|
||||||
|
DecisionTreeFactor factor(
|
||||||
|
D& C & B & A,
|
||||||
|
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
|
||||||
|
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
|
||||||
|
|
||||||
|
} // namespace pruning_fixture
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check if computing the correct threshold works.
|
||||||
|
TEST(DecisionTreeFactor, ComputeThreshold) {
|
||||||
|
using namespace pruning_fixture;
|
||||||
|
|
||||||
|
// Only keep the leaves with the top 5 values.
|
||||||
|
double threshold = f.computeThreshold(5);
|
||||||
|
EXPECT_DOUBLES_EQUAL(4.0, threshold, 1e-9);
|
||||||
|
|
||||||
|
// Check for more extreme pruning where we only keep the top 2 leaves
|
||||||
|
threshold = f.computeThreshold(2);
|
||||||
|
EXPECT_DOUBLES_EQUAL(7.0, threshold, 1e-9);
|
||||||
|
|
||||||
|
threshold = factor.computeThreshold(5);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.99995287, threshold, 1e-9);
|
||||||
|
|
||||||
|
threshold = factor.computeThreshold(3);
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.0, threshold, 1e-9);
|
||||||
|
|
||||||
|
threshold = factor.computeThreshold(6);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.61247742, threshold, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check pruning of the decision tree works as expected.
|
// Check pruning of the decision tree works as expected.
|
||||||
TEST(DecisionTreeFactor, Prune) {
|
TEST(DecisionTreeFactor, Prune) {
|
||||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
using namespace pruning_fixture;
|
||||||
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
|
|
||||||
|
|
||||||
// Only keep the leaves with the top 5 values.
|
// Only keep the leaves with the top 5 values.
|
||||||
size_t maxNrAssignments = 5;
|
size_t maxNrAssignments = 5;
|
||||||
|
@ -160,12 +195,6 @@ TEST(DecisionTreeFactor, Prune) {
|
||||||
DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
|
DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
|
||||||
EXPECT(assert_equal(expected2, pruned2));
|
EXPECT(assert_equal(expected2, pruned2));
|
||||||
|
|
||||||
DiscreteKey D(4, 2);
|
|
||||||
DecisionTreeFactor factor(
|
|
||||||
D & C & B & A,
|
|
||||||
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
|
|
||||||
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
|
|
||||||
|
|
||||||
DecisionTreeFactor expected3(D & C & B & A,
|
DecisionTreeFactor expected3(D & C & B & A,
|
||||||
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
|
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
|
||||||
"0.999952870000 1.0 1.0 1.0 1.0");
|
"0.999952870000 1.0 1.0 1.0 1.0");
|
||||||
|
|
Loading…
Reference in New Issue