parent
2940e69a73
commit
2f4133fd49
|
@ -93,7 +93,8 @@ namespace gtsam {
|
||||||
/// print
|
/// print
|
||||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
std::cout << s << " Leaf [" << nrAssignments() << "]"
|
||||||
|
<< valueFormatter(constant_) << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Write graphviz format to stream `os`. */
|
/** Write graphviz format to stream `os`. */
|
||||||
|
@ -827,6 +828,16 @@ namespace gtsam {
|
||||||
return total;
|
return total;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
size_t DecisionTree<L, Y>::nrAssignments() const {
|
||||||
|
size_t n = 0;
|
||||||
|
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
|
||||||
|
n += leaf.nrAssignments();
|
||||||
|
});
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
// fold is just done with a visit
|
// fold is just done with a visit
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
|
|
|
@ -299,6 +299,42 @@ namespace gtsam {
|
||||||
/// Return the number of leaves in the tree.
|
/// Return the number of leaves in the tree.
|
||||||
size_t nrLeaves() const;
|
size_t nrLeaves() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief This is a convenience function which returns the total number of
|
||||||
|
* leaf assignments in the decision tree.
|
||||||
|
* This function is not used for anymajor operations within the discrete
|
||||||
|
* factor graph framework.
|
||||||
|
*
|
||||||
|
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
|
||||||
|
* binary tree each leaf has 2 assignments. This includes counts removed
|
||||||
|
* from implicit pruning hence, it will always be >= nrLeaves().
|
||||||
|
*
|
||||||
|
* E.g. we have a decision tree as below, where each node has 2 branches:
|
||||||
|
*
|
||||||
|
* Choice(m1)
|
||||||
|
* 0 Choice(m0)
|
||||||
|
* 0 0 Leaf 0.0
|
||||||
|
* 0 1 Leaf 0.0
|
||||||
|
* 1 Choice(m0)
|
||||||
|
* 1 0 Leaf 1.0
|
||||||
|
* 1 1 Leaf 2.0
|
||||||
|
*
|
||||||
|
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
|
||||||
|
* and 4 leaves.
|
||||||
|
*
|
||||||
|
* In the pruned form, the number of assignments is still 4 but the number
|
||||||
|
* of leaves is now 3, as below:
|
||||||
|
*
|
||||||
|
* Choice(m1)
|
||||||
|
* 0 Leaf 0.0
|
||||||
|
* 1 Choice(m0)
|
||||||
|
* 1 0 Leaf 1.0
|
||||||
|
* 1 1 Leaf 2.0
|
||||||
|
*
|
||||||
|
* @return size_t
|
||||||
|
*/
|
||||||
|
size_t nrAssignments() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Fold a binary function over the tree, returning accumulator.
|
* @brief Fold a binary function over the tree, returning accumulator.
|
||||||
*
|
*
|
||||||
|
|
|
@ -531,6 +531,38 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
||||||
EXPECT_LONGS_EQUAL(5, count);
|
EXPECT_LONGS_EQUAL(5, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test number of assignments.
|
||||||
|
TEST(DecisionTree, NrAssignments2) {
|
||||||
|
using gtsam::symbol_shorthand::M;
|
||||||
|
|
||||||
|
std::vector<double> probs = {0, 0, 1, 2};
|
||||||
|
|
||||||
|
/* Create the decision tree
|
||||||
|
Choice(m1)
|
||||||
|
0 Leaf 0.000000
|
||||||
|
1 Choice(m0)
|
||||||
|
1 0 Leaf 1.000000
|
||||||
|
1 1 Leaf 2.000000
|
||||||
|
*/
|
||||||
|
DiscreteKeys keys{{M(1), 2}, {M(0), 2}};
|
||||||
|
DecisionTree<Key, double> dt1(keys, probs);
|
||||||
|
EXPECT_LONGS_EQUAL(4, dt1.nrAssignments());
|
||||||
|
|
||||||
|
/* Create the DecisionTree
|
||||||
|
Choice(m1)
|
||||||
|
0 Choice(m0)
|
||||||
|
0 0 Leaf 0.000000
|
||||||
|
0 1 Leaf 1.000000
|
||||||
|
1 Choice(m0)
|
||||||
|
1 0 Leaf 0.000000
|
||||||
|
1 1 Leaf 2.000000
|
||||||
|
*/
|
||||||
|
DiscreteKeys keys2{{M(0), 2}, {M(1), 2}};
|
||||||
|
DecisionTree<Key, double> dt2(keys2, probs);
|
||||||
|
EXPECT_LONGS_EQUAL(4, dt2.nrAssignments());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue