diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index c157a2543..c338bb86f 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -308,7 +308,7 @@ TEST(DecisionTree, Containers) { StringContainerTree tree; // Create small two-level tree - string A("A"), B("B"), C("C"); + string A("A"), B("B"); DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3)); // Check conversion @@ -324,7 +324,7 @@ TEST(DecisionTree, Containers) { // Test visit. TEST(DecisionTree, visit) { // Create small two-level tree - string A("A"), B("B"), C("C"); + string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); double sum = 0.0; auto visitor = [&](int y) { sum += y; }; @@ -336,7 +336,7 @@ TEST(DecisionTree, visit) { // Test visit, with Choices argument. TEST(DecisionTree, visitWith) { // Create small two-level tree - string A("A"), B("B"), C("C"); + string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); double sum = 0.0; auto visitor = [&](const Assignment& choices, int y) { sum += y; }; @@ -348,7 +348,7 @@ TEST(DecisionTree, visitWith) { // Test fold. TEST(DecisionTree, fold) { // Create small two-level tree - string A("A"), B("B"), C("C"); + string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); auto add = [](const int& y, double x) { return y + x; }; double sum = tree.fold(add, 0.0); @@ -359,14 +359,14 @@ TEST(DecisionTree, fold) { // Test retrieving all labels. TEST(DecisionTree, labels) { // Create small two-level tree - string A("A"), B("B"), C("C"); + string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); auto labels = tree.labels(); EXPECT_LONGS_EQUAL(2, labels.size()); } /* ******************************************************************************** */ -// Test retrieving all labels. +// Test unzip method. TEST(DecisionTree, unzip) { using DTP = DecisionTree>; using DT1 = DecisionTree; @@ -390,6 +390,29 @@ TEST(DecisionTree, unzip) { EXPECT(tree2.equals(dt2)); } +/* ************************************************************************** */ +// Test thresholding. +TEST(DecisionTree, threshold) { + // Create three level tree + vector keys; + keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2); + DT tree(keys, "0 1 2 3 4 5 6 7"); + + // Check number of elements equal to zero + auto count = [](const int& value, int count) { + return value == 0 ? count + 1 : count; + }; + EXPECT_LONGS_EQUAL(1, tree.fold(count, 0)); + + // Now threshold + auto threshold = [](int value) { return value < 5 ? 0 : value; }; + DT thresholded(tree, threshold); + + // Check number of elements equal to zero now = 5 + // TODO(frank): it is 2, because the pruned branches are counted as 1! + EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0)); +} + /* ************************************************************************* */ int main() { TestResult tr;