diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 9d618dea0..2f55fb3fc 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -649,8 +649,9 @@ namespace gtsam { throw std::invalid_argument("DecisionTree::create invalid argument"); } auto choice = std::make_shared(begin->first, endY - beginY); - for (ValueIt y = beginY; y != endY; y++) + for (ValueIt y = beginY; y != endY; y++) { choice->push_back(NodePtr(new Leaf(*y))); + } return Choice::Unique(choice); } @@ -827,6 +828,16 @@ namespace gtsam { return total; } + /****************************************************************************/ + template + size_t DecisionTree::nrAssignments() const { + size_t n = 0; + this->visitLeaf([&n](const DecisionTree::Leaf& leaf) { + n += leaf.nrAssignments(); + }); + return n; + } + /****************************************************************************/ // fold is just done with a visit template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index ed1908485..70d528648 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -299,6 +299,15 @@ namespace gtsam { /// Return the number of leaves in the tree. size_t nrLeaves() const; + /** + * @brief Return the number of total leaf assignments. + * This includes counts removed from implicit pruning hence, + * it will always be >= nrLeaves(). + * + * @return size_t + */ + size_t nrAssignments() const; + /** * @brief Fold a binary function over the tree, returning accumulator. *