diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 9d618dea0..19e3e2887 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -93,7 +93,8 @@ namespace gtsam { /// print void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { - std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; + std::cout << s << " Leaf [" << nrAssignments() << "]" + << valueFormatter(constant_) << std::endl; } /** Write graphviz format to stream `os`. */ @@ -827,6 +828,16 @@ namespace gtsam { return total; } + /****************************************************************************/ + template + size_t DecisionTree::nrAssignments() const { + size_t n = 0; + this->visitLeaf([&n](const DecisionTree::Leaf& leaf) { + n += leaf.nrAssignments(); + }); + return n; + } + /****************************************************************************/ // fold is just done with a visit template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index ed1908485..bee0ce5c7 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -299,6 +299,42 @@ namespace gtsam { /// Return the number of leaves in the tree. size_t nrLeaves() const; + /** + * @brief This is a convenience function which returns the total number of + * leaf assignments in the decision tree. + * This function is not used for anymajor operations within the discrete + * factor graph framework. + * + * Leaf assignments represent the cardinality of each leaf node, e.g. in a + * binary tree each leaf has 2 assignments. This includes counts removed + * from implicit pruning hence, it will always be >= nrLeaves(). + * + * E.g. we have a decision tree as below, where each node has 2 branches: + * + * Choice(m1) + * 0 Choice(m0) + * 0 0 Leaf 0.0 + * 0 1 Leaf 0.0 + * 1 Choice(m0) + * 1 0 Leaf 1.0 + * 1 1 Leaf 2.0 + * + * In the unpruned form, the tree will have 4 assignments, 2 for each key, + * and 4 leaves. + * + * In the pruned form, the number of assignments is still 4 but the number + * of leaves is now 3, as below: + * + * Choice(m1) + * 0 Leaf 0.0 + * 1 Choice(m0) + * 1 0 Leaf 1.0 + * 1 1 Leaf 2.0 + * + * @return size_t + */ + size_t nrAssignments() const; + /** * @brief Fold a binary function over the tree, returning accumulator. * diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index d2a94ddc3..395915068 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -531,6 +531,38 @@ TEST(DecisionTree, ApplyWithAssignment) { EXPECT_LONGS_EQUAL(5, count); } +/* ************************************************************************** */ +// Test number of assignments. +TEST(DecisionTree, NrAssignments2) { + using gtsam::symbol_shorthand::M; + + std::vector probs = {0, 0, 1, 2}; + + /* Create the decision tree + Choice(m1) + 0 Leaf 0.000000 + 1 Choice(m0) + 1 0 Leaf 1.000000 + 1 1 Leaf 2.000000 + */ + DiscreteKeys keys{{M(1), 2}, {M(0), 2}}; + DecisionTree dt1(keys, probs); + EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); + + /* Create the DecisionTree + Choice(m1) + 0 Choice(m0) + 0 0 Leaf 0.000000 + 0 1 Leaf 1.000000 + 1 Choice(m0) + 1 0 Leaf 0.000000 + 1 1 Leaf 2.000000 + */ + DiscreteKeys keys2{{M(0), 2}, {M(1), 2}}; + DecisionTree dt2(keys2, probs); + EXPECT_LONGS_EQUAL(4, dt2.nrAssignments()); +} + /* ************************************************************************* */ int main() { TestResult tr;