Revert "remove nrAssignments from the DecisionTree"

This reverts commit 647d3c0744.
release/4.3a0
Varun Agrawal 2023-07-10 19:39:28 -04:00
parent 2940e69a73
commit 2f4133fd49
3 changed files with 80 additions and 1 deletions

View File

@ -93,7 +93,8 @@ namespace gtsam {
/// print /// print
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; std::cout << s << " Leaf [" << nrAssignments() << "]"
<< valueFormatter(constant_) << std::endl;
} }
/** Write graphviz format to stream `os`. */ /** Write graphviz format to stream `os`. */
@ -827,6 +828,16 @@ namespace gtsam {
return total; return total;
} }
/****************************************************************************/
template <typename L, typename Y>
size_t DecisionTree<L, Y>::nrAssignments() const {
size_t n = 0;
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
n += leaf.nrAssignments();
});
return n;
}
/****************************************************************************/ /****************************************************************************/
// fold is just done with a visit // fold is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>

View File

@ -299,6 +299,42 @@ namespace gtsam {
/// Return the number of leaves in the tree. /// Return the number of leaves in the tree.
size_t nrLeaves() const; size_t nrLeaves() const;
/**
* @brief This is a convenience function which returns the total number of
* leaf assignments in the decision tree.
* This function is not used for anymajor operations within the discrete
* factor graph framework.
*
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
* binary tree each leaf has 2 assignments. This includes counts removed
* from implicit pruning hence, it will always be >= nrLeaves().
*
* E.g. we have a decision tree as below, where each node has 2 branches:
*
* Choice(m1)
* 0 Choice(m0)
* 0 0 Leaf 0.0
* 0 1 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
* and 4 leaves.
*
* In the pruned form, the number of assignments is still 4 but the number
* of leaves is now 3, as below:
*
* Choice(m1)
* 0 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* @return size_t
*/
size_t nrAssignments() const;
/** /**
* @brief Fold a binary function over the tree, returning accumulator. * @brief Fold a binary function over the tree, returning accumulator.
* *

View File

@ -531,6 +531,38 @@ TEST(DecisionTree, ApplyWithAssignment) {
EXPECT_LONGS_EQUAL(5, count); EXPECT_LONGS_EQUAL(5, count);
} }
/* ************************************************************************** */
// Test number of assignments.
TEST(DecisionTree, NrAssignments2) {
using gtsam::symbol_shorthand::M;
std::vector<double> probs = {0, 0, 1, 2};
/* Create the decision tree
Choice(m1)
0 Leaf 0.000000
1 Choice(m0)
1 0 Leaf 1.000000
1 1 Leaf 2.000000
*/
DiscreteKeys keys{{M(1), 2}, {M(0), 2}};
DecisionTree<Key, double> dt1(keys, probs);
EXPECT_LONGS_EQUAL(4, dt1.nrAssignments());
/* Create the DecisionTree
Choice(m1)
0 Choice(m0)
0 0 Leaf 0.000000
0 1 Leaf 1.000000
1 Choice(m0)
1 0 Leaf 0.000000
1 1 Leaf 2.000000
*/
DiscreteKeys keys2{{M(0), 2}, {M(1), 2}};
DecisionTree<Key, double> dt2(keys2, probs);
EXPECT_LONGS_EQUAL(4, dt2.nrAssignments());
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;