From dbd0a7d3ba166cc4545f7168aa929e42392b14ad Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 8 Jun 2023 09:53:22 -0400 Subject: [PATCH] re-enable DecisionTree tests --- gtsam/discrete/tests/testDecisionTree.cpp | 56 ++++++++++++++++------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index beb100a61..fb49908e2 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -121,7 +121,7 @@ struct Ring { /* ************************************************************************** */ // test DT -TEST_DISABLED(DecisionTree, example) { +TEST(DecisionTree, example) { // Create labels string A("A"), B("B"), C("C"); @@ -231,7 +231,7 @@ TEST_DISABLED(DecisionTree, example) { bool bool_of_int(const int& y) { return y != 0; }; typedef DecisionTree StringBoolTree; -TEST_DISABLED(DecisionTree, ConvertValuesOnly) { +TEST(DecisionTree, ConvertValuesOnly) { // Create labels string A("A"), B("B"); @@ -252,7 +252,7 @@ TEST_DISABLED(DecisionTree, ConvertValuesOnly) { enum Label { U, V, X, Y, Z }; typedef DecisionTree LabelBoolTree; -TEST_DISABLED(DecisionTree, ConvertBoth) { +TEST(DecisionTree, ConvertBoth) { // Create labels string A("A"), B("B"); @@ -279,7 +279,7 @@ TEST_DISABLED(DecisionTree, ConvertBoth) { /* ************************************************************************** */ // test Compose expansion -TEST_DISABLED(DecisionTree, Compose) { +TEST(DecisionTree, Compose) { // Create labels string A("A"), B("B"), C("C"); @@ -305,7 +305,7 @@ TEST_DISABLED(DecisionTree, Compose) { /* ************************************************************************** */ // Check we can create a decision tree of containers. -TEST_DISABLED(DecisionTree, Containers) { +TEST(DecisionTree, Containers) { using Container = std::vector; using StringContainerTree = DecisionTree; @@ -327,7 +327,7 @@ TEST_DISABLED(DecisionTree, Containers) { /* ************************************************************************** */ // Test nrAssignments. -TEST_DISABLED(DecisionTree, NrAssignments) { +TEST(DecisionTree, NrAssignments) { const std::pair A("A", 2), B("B", 2), C("C", 2); DT tree({A, B, C}, "1 1 1 1 1 1 1 1"); @@ -375,7 +375,7 @@ TEST_DISABLED(DecisionTree, NrAssignments) { /* ************************************************************************** */ // Test visit. -TEST_DISABLED(DecisionTree, visit) { +TEST(DecisionTree, visit) { // Create small two-level tree string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); @@ -387,7 +387,7 @@ TEST_DISABLED(DecisionTree, visit) { /* ************************************************************************** */ // Test visit, with Choices argument. -TEST_DISABLED(DecisionTree, visitWith) { +TEST(DecisionTree, visitWith) { // Create small two-level tree string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); @@ -399,7 +399,7 @@ TEST_DISABLED(DecisionTree, visitWith) { /* ************************************************************************** */ // Test visit, with Choices argument. -TEST_DISABLED(DecisionTree, VisitWithPruned) { +TEST(DecisionTree, VisitWithPruned) { // Create pruned tree std::pair A("A", 2), B("B", 2), C("C", 2); std::vector> labels = {C, B, A}; @@ -437,7 +437,7 @@ TEST_DISABLED(DecisionTree, VisitWithPruned) { /* ************************************************************************** */ // Test fold. -TEST_DISABLED(DecisionTree, fold) { +TEST(DecisionTree, fold) { // Create small two-level tree string A("A"), B("B"); DT tree(B, DT(A, 1, 1), DT(A, 2, 3)); @@ -448,7 +448,7 @@ TEST_DISABLED(DecisionTree, fold) { /* ************************************************************************** */ // Test retrieving all labels. -TEST_DISABLED(DecisionTree, labels) { +TEST(DecisionTree, labels) { // Create small two-level tree string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); @@ -458,7 +458,7 @@ TEST_DISABLED(DecisionTree, labels) { /* ************************************************************************** */ // Test unzip method. -TEST_DISABLED(DecisionTree, unzip) { +TEST(DecisionTree, unzip) { using DTP = DecisionTree>; using DT1 = DecisionTree; using DT2 = DecisionTree; @@ -479,7 +479,7 @@ TEST_DISABLED(DecisionTree, unzip) { /* ************************************************************************** */ // Test thresholding. -TEST_DISABLED(DecisionTree, threshold) { +TEST(DecisionTree, threshold) { // Create three level tree const vector keys{DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2)}; @@ -541,7 +541,7 @@ TEST(DecisionTree, ApplyWithAssignment) { /* ************************************************************************** */ // Test number of assignments. -TEST(DecisionTree, NrAssignments2) { +TEST(DecisionTree, Constructor) { using gtsam::symbol_shorthand::M; std::vector probs = {0, 0, 1, 2}; @@ -551,6 +551,30 @@ TEST(DecisionTree, NrAssignments2) { EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); dt1.print("", DefaultKeyFormatter, [](double x) { return std::to_string(x);}); + DiscreteKeys keys2{{M(0), 2}, {M(1), 2}}; + DecisionTree dt2(keys2, probs); + std::cout << "\n" << std::endl; + dt2.print("", DefaultKeyFormatter, [](double x) { return std::to_string(x);}); +} + +/* ************************************************************************** */ +// Test number of assignments. +TEST(DecisionTree, NrAssignments2) { + using gtsam::symbol_shorthand::M; + + std::vector probs = {0, 0, 1, 2}; + + /* Create the decision tree + Choice(m1) + 0 Leaf 0.000000 + 1 Choice(m0) + 1 0 Leaf 1.000000 + 1 1 Leaf 2.000000 + */ + DiscreteKeys keys{{M(1), 2}, {M(0), 2}}; + DecisionTree dt1(keys, probs); + EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); + /* Create the DecisionTree Choice(m1) 0 Choice(m0) @@ -562,9 +586,7 @@ TEST(DecisionTree, NrAssignments2) { */ DiscreteKeys keys2{{M(0), 2}, {M(1), 2}}; DecisionTree dt2(keys2, probs); - std::cout << "\n\n" << std::endl; - dt2.print("", DefaultKeyFormatter, [](double x) { return std::to_string(x);}); - // EXPECT_LONGS_EQUAL(4, dt2.nrAssignments()); + EXPECT_LONGS_EQUAL(4, dt2.nrAssignments()); } /* ************************************************************************* */