diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index d2a94ddc3..c9973a3e9 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -329,6 +330,9 @@ TEST(DecisionTree, Containers) { TEST(DecisionTree, NrAssignments) { const std::pair A("A", 2), B("B", 2), C("C", 2); DT tree({A, B, C}, "1 1 1 1 1 1 1 1"); + + EXPECT_LONGS_EQUAL(8, tree.nrAssignments()); + EXPECT(tree.root_->isLeaf()); auto leaf = std::dynamic_pointer_cast(tree.root_); EXPECT_LONGS_EQUAL(8, leaf->nrAssignments()); @@ -348,6 +352,8 @@ TEST(DecisionTree, NrAssignments) { 1 1 Leaf 5 */ + EXPECT_LONGS_EQUAL(8, tree2.nrAssignments()); + auto root = std::dynamic_pointer_cast(tree2.root_); CHECK(root); auto choice0 = std::dynamic_pointer_cast(root->branches()[0]); @@ -531,6 +537,23 @@ TEST(DecisionTree, ApplyWithAssignment) { EXPECT_LONGS_EQUAL(5, count); } +/* ************************************************************************** */ +// Test number of assignments. +TEST(DecisionTree, NrAssignments2) { + using gtsam::symbol_shorthand::M; + + DiscreteKeys keys{{M(1), 2}, {M(0), 2}}; + std::vector probs = {0, 0, 1, 2}; + DecisionTree dt1(keys, probs); + + EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); + + DiscreteKeys keys2{{M(0), 2}, {M(1), 2}}; + DecisionTree dt2(keys2, probs); + //TODO(Varun) The below is failing, because the number of assignments aren't being set correctly. + EXPECT_LONGS_EQUAL(4, dt2.nrAssignments()); +} + /* ************************************************************************* */ int main() { TestResult tr;