diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 10532284c..156177d03 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -53,17 +53,26 @@ namespace gtsam { /** constant stored in this leaf */ Y constant_; + /** The number of assignments contained within this leaf. + * Particularly useful when leaves have been pruned. + */ + size_t nrAssignments_; + /// Default constructor for serialization. Leaf() {} /// Constructor from constant - Leaf(const Y& constant) : constant_(constant) {} + Leaf(const Y& constant, size_t nrAssignments = 1) + : constant_(constant), nrAssignments_(nrAssignments) {} /// Return the constant const Y& constant() const { return constant_; } + /// Return the number of assignments contained within this leaf. + size_t nrAssignments() const { return nrAssignments_; } + /// Leaf-Leaf equality bool sameLeaf(const Leaf& q) const override { return constant_ == q.constant_; @@ -84,7 +93,8 @@ namespace gtsam { /// print void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { - std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; + std::cout << s << " Leaf [" << nrAssignments() << "]" + << valueFormatter(constant_) << std::endl; } /** Write graphviz format to stream `os`. */ @@ -104,14 +114,14 @@ namespace gtsam { /** apply unary operator */ NodePtr apply(const Unary& op) const override { - NodePtr f(new Leaf(op(constant_))); + NodePtr f(new Leaf(op(constant_), nrAssignments_)); return f; } /// Apply unary operator with assignment NodePtr apply(const UnaryAssignment& op, const Assignment& assignment) const override { - NodePtr f(new Leaf(op(assignment, constant_))); + NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_)); return f; } @@ -127,7 +137,9 @@ namespace gtsam { // Applying binary operator to two leaves results in a leaf NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { // fL op gL - NodePtr h(new Leaf(op(fL.constant_, constant_))); + // TODO(Varun) nrAssignments setting is not correct. + // Depending on f and g, the nrAssignments can be different. This is a bug! + NodePtr h(new Leaf(op(fL.constant_, constant_), fL.nrAssignments())); return h; } @@ -138,7 +150,7 @@ namespace gtsam { /** choose a branch, create new memory ! */ NodePtr choose(const L& label, size_t index) const override { - return NodePtr(new Leaf(constant())); + return NodePtr(new Leaf(constant(), nrAssignments())); } bool isLeaf() const override { return true; } @@ -153,6 +165,7 @@ namespace gtsam { void serialize(ARCHIVE& ar, const unsigned int /*version*/) { ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar& BOOST_SERIALIZATION_NVP(constant_); + ar& BOOST_SERIALIZATION_NVP(nrAssignments_); } #endif }; // Leaf @@ -222,8 +235,16 @@ namespace gtsam { assert(f->branches().size() > 0); NodePtr f0 = f->branches_[0]; + // Compute total number of assignments + size_t nrAssignments = 0; + for (auto branch : f->branches()) { + if (auto leaf = std::dynamic_pointer_cast(branch)) { + nrAssignments += leaf->nrAssignments(); + } + } NodePtr newLeaf( - new Leaf(std::dynamic_pointer_cast(f0)->constant())); + new Leaf(std::dynamic_pointer_cast(f0)->constant(), + nrAssignments)); return newLeaf; } #endif @@ -709,7 +730,7 @@ namespace gtsam { // If leaf, apply unary conversion "op" and create a unique leaf. using MXLeaf = typename DecisionTree::Leaf; if (auto leaf = std::dynamic_pointer_cast(f)) { - return NodePtr(new Leaf(Y_of_X(leaf->constant()))); + return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments())); } // Check if Choice @@ -856,6 +877,16 @@ namespace gtsam { return total; } + /****************************************************************************/ + template + size_t DecisionTree::nrAssignments() const { + size_t n = 0; + this->visitLeaf([&n](const DecisionTree::Leaf& leaf) { + n += leaf.nrAssignments(); + }); + return n; + } + /****************************************************************************/ // fold is just done with a visit template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 733ab1ad1..9a8eac65e 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -307,6 +307,42 @@ namespace gtsam { /// Return the number of leaves in the tree. size_t nrLeaves() const; + /** + * @brief This is a convenience function which returns the total number of + * leaf assignments in the decision tree. + * This function is not used for anymajor operations within the discrete + * factor graph framework. + * + * Leaf assignments represent the cardinality of each leaf node, e.g. in a + * binary tree each leaf has 2 assignments. This includes counts removed + * from implicit pruning hence, it will always be >= nrLeaves(). + * + * E.g. we have a decision tree as below, where each node has 2 branches: + * + * Choice(m1) + * 0 Choice(m0) + * 0 0 Leaf 0.0 + * 0 1 Leaf 0.0 + * 1 Choice(m0) + * 1 0 Leaf 1.0 + * 1 1 Leaf 2.0 + * + * In the unpruned form, the tree will have 4 assignments, 2 for each key, + * and 4 leaves. + * + * In the pruned form, the number of assignments is still 4 but the number + * of leaves is now 3, as below: + * + * Choice(m1) + * 0 Leaf 0.0 + * 1 Choice(m0) + * 1 0 Leaf 1.0 + * 1 1 Leaf 2.0 + * + * @return size_t + */ + size_t nrAssignments() const; + /** * @brief Fold a binary function over the tree, returning accumulator. * diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index cdbad8c1c..653360fb7 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -328,6 +328,59 @@ TEST(DecisionTree, Containers) { StringContainerTree converted(stringIntTree, container_of_int); } +/* ************************************************************************** */ +// Test nrAssignments. +TEST(DecisionTree, NrAssignments) { + const std::pair A("A", 2), B("B", 2), C("C", 2); + DT tree({A, B, C}, "1 1 1 1 1 1 1 1"); + + EXPECT_LONGS_EQUAL(8, tree.nrAssignments()); + +#ifdef GTSAM_DT_MERGING + EXPECT(tree.root_->isLeaf()); + auto leaf = std::dynamic_pointer_cast(tree.root_); + EXPECT_LONGS_EQUAL(8, leaf->nrAssignments()); +#endif + + DT tree2({C, B, A}, "1 1 1 2 3 4 5 5"); + /* The tree is + Choice(C) + 0 Choice(B) + 0 0 Leaf 1 + 0 1 Choice(A) + 0 1 0 Leaf 1 + 0 1 1 Leaf 2 + 1 Choice(B) + 1 0 Choice(A) + 1 0 0 Leaf 3 + 1 0 1 Leaf 4 + 1 1 Leaf 5 + */ + + EXPECT_LONGS_EQUAL(8, tree2.nrAssignments()); + + auto root = std::dynamic_pointer_cast(tree2.root_); + CHECK(root); + auto choice0 = std::dynamic_pointer_cast(root->branches()[0]); + CHECK(choice0); + +#ifdef GTSAM_DT_MERGING + EXPECT(choice0->branches()[0]->isLeaf()); + auto choice00 = std::dynamic_pointer_cast(choice0->branches()[0]); + CHECK(choice00); + EXPECT_LONGS_EQUAL(2, choice00->nrAssignments()); + + auto choice1 = std::dynamic_pointer_cast(root->branches()[1]); + CHECK(choice1); + auto choice10 = std::dynamic_pointer_cast(choice1->branches()[0]); + CHECK(choice10); + auto choice11 = std::dynamic_pointer_cast(choice1->branches()[1]); + CHECK(choice11); + EXPECT(choice11->isLeaf()); + EXPECT_LONGS_EQUAL(2, choice11->nrAssignments()); +#endif +} + /* ************************************************************************** */ // Test visit. TEST(DecisionTree, visit) { @@ -540,6 +593,38 @@ TEST(DecisionTree, ApplyWithAssignment) { #endif } +/* ************************************************************************** */ +// Test number of assignments. +TEST(DecisionTree, NrAssignments2) { + using gtsam::symbol_shorthand::M; + + std::vector probs = {0, 0, 1, 2}; + + /* Create the decision tree + Choice(m1) + 0 Leaf 0.000000 + 1 Choice(m0) + 1 0 Leaf 1.000000 + 1 1 Leaf 2.000000 + */ + DiscreteKeys keys{{M(1), 2}, {M(0), 2}}; + DecisionTree dt1(keys, probs); + EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); + + /* Create the DecisionTree + Choice(m1) + 0 Choice(m0) + 0 0 Leaf 0.000000 + 0 1 Leaf 1.000000 + 1 Choice(m0) + 1 0 Leaf 0.000000 + 1 1 Leaf 2.000000 + */ + DiscreteKeys keys2{{M(0), 2}, {M(1), 2}}; + DecisionTree dt2(keys2, probs); + EXPECT_LONGS_EQUAL(4, dt2.nrAssignments()); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 341eb63e3..6e8621595 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -349,6 +349,119 @@ TEST(DiscreteFactorGraph, markdown) { EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9); } +TEST(DiscreteFactorGraph, NrAssignments) { +#ifdef GTSAM_DT_MERGING + string expected_dfg = R"( +size: 2 +factor 0: f[ (m0,2), (m1,2), (m2,2), ] + Choice(m2) + 0 Choice(m1) + 0 0 Leaf [2] 0 + 0 1 Choice(m0) + 0 1 0 Leaf [1]0.27527634 + 0 1 1 Leaf [1] 0 + 1 Choice(m1) + 1 0 Leaf [2] 0 + 1 1 Choice(m0) + 1 1 0 Leaf [1]0.44944733 + 1 1 1 Leaf [1]0.27527634 +factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] + Choice(m3) + 0 Choice(m2) + 0 0 Choice(m1) + 0 0 0 Leaf [2] 1 + 0 0 1 Leaf [2]0.015366387 + 0 1 Choice(m1) + 0 1 0 Leaf [2] 1 + 0 1 1 Choice(m0) + 0 1 1 0 Leaf [1] 1 + 0 1 1 1 Leaf [1]0.015365663 + 1 Choice(m2) + 1 0 Choice(m1) + 1 0 0 Leaf [2] 1 + 1 0 1 Choice(m0) + 1 0 1 0 Leaf [1]0.0094115739 + 1 0 1 1 Leaf [1]0.0094115652 + 1 1 Choice(m1) + 1 1 0 Leaf [2] 1 + 1 1 1 Choice(m0) + 1 1 1 0 Leaf [1] 1 + 1 1 1 1 Leaf [1]0.009321081 +)"; +#else + string expected_dfg = R"( +size: 2 +factor 0: f[ (m0,2), (m1,2), (m2,2), ] + Choice(m2) + 0 Choice(m1) + 0 0 Choice(m0) + 0 0 0 Leaf [1] 0 + 0 0 1 Leaf [1] 0 + 0 1 Choice(m0) + 0 1 0 Leaf [1]0.27527634 + 0 1 1 Leaf [1]0.44944733 + 1 Choice(m1) + 1 0 Choice(m0) + 1 0 0 Leaf [1] 0 + 1 0 1 Leaf [1] 0 + 1 1 Choice(m0) + 1 1 0 Leaf [1] 0 + 1 1 1 Leaf [1]0.27527634 +factor 1: f[ (m0,2), (m1,2), (m2,2), (m3,2), ] + Choice(m3) + 0 Choice(m2) + 0 0 Choice(m1) + 0 0 0 Choice(m0) + 0 0 0 0 Leaf [1] 1 + 0 0 0 1 Leaf [1] 1 + 0 0 1 Choice(m0) + 0 0 1 0 Leaf [1]0.015366387 + 0 0 1 1 Leaf [1]0.015366387 + 0 1 Choice(m1) + 0 1 0 Choice(m0) + 0 1 0 0 Leaf [1] 1 + 0 1 0 1 Leaf [1] 1 + 0 1 1 Choice(m0) + 0 1 1 0 Leaf [1] 1 + 0 1 1 1 Leaf [1]0.015365663 + 1 Choice(m2) + 1 0 Choice(m1) + 1 0 0 Choice(m0) + 1 0 0 0 Leaf [1] 1 + 1 0 0 1 Leaf [1] 1 + 1 0 1 Choice(m0) + 1 0 1 0 Leaf [1]0.0094115739 + 1 0 1 1 Leaf [1]0.0094115652 + 1 1 Choice(m1) + 1 1 0 Choice(m0) + 1 1 0 0 Leaf [1] 1 + 1 1 0 1 Leaf [1] 1 + 1 1 1 Choice(m0) + 1 1 1 0 Leaf [1] 1 + 1 1 1 1 Leaf [1]0.009321081 +)"; +#endif + + DiscreteKeys d0{{M(0), 2}, {M(1), 2}, {M(2), 2}}; + std::vector p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468}; + AlgebraicDecisionTree dt(d0, p0); + //TODO(Varun) Passing ADT to DiscreteConditional causes nrAssignments to get messed up + // Issue seems to be in DecisionTreeFactor.cpp L104 + DiscreteConditional f0(3, DecisionTreeFactor(d0, dt)); + + DiscreteKeys d1{{M(0), 2}, {M(1), 2}, {M(2), 2}, {M(3), 2}}; + std::vector p1 = { + 1, 1, 1, 1, 0.015366387, 0.0094115739, 1, 1, + 1, 1, 1, 1, 0.015366387, 0.0094115652, 0.015365663, 0.009321081}; + DecisionTreeFactor f1(d1, p1); + + DiscreteFactorGraph dfg; + dfg.add(f0); + dfg.add(f1); + + EXPECT(assert_print_equal(expected_dfg, dfg)); +} + /* ************************************************************************* */ int main() { TestResult tr;