parent
2940e69a73
commit
2f4133fd49
|
@ -93,7 +93,8 @@ namespace gtsam {
|
|||
/// print
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const override {
|
||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
||||
std::cout << s << " Leaf [" << nrAssignments() << "]"
|
||||
<< valueFormatter(constant_) << std::endl;
|
||||
}
|
||||
|
||||
/** Write graphviz format to stream `os`. */
|
||||
|
@ -827,6 +828,16 @@ namespace gtsam {
|
|||
return total;
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
size_t DecisionTree<L, Y>::nrAssignments() const {
|
||||
size_t n = 0;
|
||||
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
|
||||
n += leaf.nrAssignments();
|
||||
});
|
||||
return n;
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// fold is just done with a visit
|
||||
template <typename L, typename Y>
|
||||
|
|
|
@ -299,6 +299,42 @@ namespace gtsam {
|
|||
/// Return the number of leaves in the tree.
|
||||
size_t nrLeaves() const;
|
||||
|
||||
/**
|
||||
* @brief This is a convenience function which returns the total number of
|
||||
* leaf assignments in the decision tree.
|
||||
* This function is not used for anymajor operations within the discrete
|
||||
* factor graph framework.
|
||||
*
|
||||
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
|
||||
* binary tree each leaf has 2 assignments. This includes counts removed
|
||||
* from implicit pruning hence, it will always be >= nrLeaves().
|
||||
*
|
||||
* E.g. we have a decision tree as below, where each node has 2 branches:
|
||||
*
|
||||
* Choice(m1)
|
||||
* 0 Choice(m0)
|
||||
* 0 0 Leaf 0.0
|
||||
* 0 1 Leaf 0.0
|
||||
* 1 Choice(m0)
|
||||
* 1 0 Leaf 1.0
|
||||
* 1 1 Leaf 2.0
|
||||
*
|
||||
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
|
||||
* and 4 leaves.
|
||||
*
|
||||
* In the pruned form, the number of assignments is still 4 but the number
|
||||
* of leaves is now 3, as below:
|
||||
*
|
||||
* Choice(m1)
|
||||
* 0 Leaf 0.0
|
||||
* 1 Choice(m0)
|
||||
* 1 0 Leaf 1.0
|
||||
* 1 1 Leaf 2.0
|
||||
*
|
||||
* @return size_t
|
||||
*/
|
||||
size_t nrAssignments() const;
|
||||
|
||||
/**
|
||||
* @brief Fold a binary function over the tree, returning accumulator.
|
||||
*
|
||||
|
|
|
@ -531,6 +531,38 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
|||
EXPECT_LONGS_EQUAL(5, count);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test number of assignments.
|
||||
TEST(DecisionTree, NrAssignments2) {
|
||||
using gtsam::symbol_shorthand::M;
|
||||
|
||||
std::vector<double> probs = {0, 0, 1, 2};
|
||||
|
||||
/* Create the decision tree
|
||||
Choice(m1)
|
||||
0 Leaf 0.000000
|
||||
1 Choice(m0)
|
||||
1 0 Leaf 1.000000
|
||||
1 1 Leaf 2.000000
|
||||
*/
|
||||
DiscreteKeys keys{{M(1), 2}, {M(0), 2}};
|
||||
DecisionTree<Key, double> dt1(keys, probs);
|
||||
EXPECT_LONGS_EQUAL(4, dt1.nrAssignments());
|
||||
|
||||
/* Create the DecisionTree
|
||||
Choice(m1)
|
||||
0 Choice(m0)
|
||||
0 0 Leaf 0.000000
|
||||
0 1 Leaf 1.000000
|
||||
1 Choice(m0)
|
||||
1 0 Leaf 0.000000
|
||||
1 1 Leaf 2.000000
|
||||
*/
|
||||
DiscreteKeys keys2{{M(0), 2}, {M(1), 2}};
|
||||
DecisionTree<Key, double> dt2(keys2, probs);
|
||||
EXPECT_LONGS_EQUAL(4, dt2.nrAssignments());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue