Merge pull request #1154 from borglab/decisiontree/nrAssignments

Record number of assignments for each leaf
release/4.3a0
Varun Agrawal 2022-04-02 10:01:50 -04:00 committed by GitHub
commit 99c01c4dba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 10 deletions

View File

@ -59,15 +59,23 @@ namespace gtsam {
/** constant stored in this leaf */ /** constant stored in this leaf */
Y constant_; Y constant_;
/** Constructor from constant */ /** The number of assignments contained within this leaf
Leaf(const Y& constant) : * Particularly useful when leaves have been pruned.
constant_(constant) {} */
size_t nrAssignments_;
/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
/** return the constant */ /** return the constant */
const Y& constant() const { const Y& constant() const {
return constant_; return constant_;
} }
/// Return the number of assignments contained within this leaf.
size_t nrAssignments() const { return nrAssignments_; }
/// Leaf-Leaf equality /// Leaf-Leaf equality
bool sameLeaf(const Leaf& q) const override { bool sameLeaf(const Leaf& q) const override {
return constant_ == q.constant_; return constant_ == q.constant_;
@ -108,14 +116,14 @@ namespace gtsam {
/** apply unary operator */ /** apply unary operator */
NodePtr apply(const Unary& op) const override { NodePtr apply(const Unary& op) const override {
NodePtr f(new Leaf(op(constant_))); NodePtr f(new Leaf(op(constant_), nrAssignments_));
return f; return f;
} }
/// Apply unary operator with assignment /// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op, NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override { const Assignment<L>& choices) const override {
NodePtr f(new Leaf(op(choices, constant_))); NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
return f; return f;
} }
@ -130,7 +138,8 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf // Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL // fL op gL
NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
return h; return h;
} }
@ -141,7 +150,7 @@ namespace gtsam {
/** choose a branch, create new memory ! */ /** choose a branch, create new memory ! */
NodePtr choose(const L& label, size_t index) const override { NodePtr choose(const L& label, size_t index) const override {
return NodePtr(new Leaf(constant())); return NodePtr(new Leaf(constant(), nrAssignments()));
} }
bool isLeaf() const override { return true; } bool isLeaf() const override { return true; }
@ -178,9 +187,16 @@ namespace gtsam {
if (f->allSame_) { if (f->allSame_) {
assert(f->branches().size() > 0); assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0]; NodePtr f0 = f->branches_[0];
assert(f0->isLeaf());
size_t nrAssignments = 0;
for(auto branch: f->branches()) {
assert(branch->isLeaf());
nrAssignments +=
boost::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
}
NodePtr newLeaf( NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant())); new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant(),
nrAssignments));
return newLeaf; return newLeaf;
} else } else
#endif #endif
@ -640,7 +656,7 @@ namespace gtsam {
// If leaf, apply unary conversion "op" and create a unique leaf. // If leaf, apply unary conversion "op" and create a unique leaf.
using MXLeaf = typename DecisionTree<M, X>::Leaf; using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) { if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant()))); return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
} }
// Check if Choice // Check if Choice

View File

@ -323,6 +323,49 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int); StringContainerTree converted(stringIntTree, container_of_int);
} }
/* ************************************************************************** */
// Test nrAssignments.
TEST(DecisionTree, NrAssignments) {
pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
EXPECT(tree.root_->isLeaf());
auto leaf = boost::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
/* The tree is
Choice(C)
0 Choice(B)
0 0 Leaf 1
0 1 Choice(A)
0 1 0 Leaf 1
0 1 1 Leaf 2
1 Choice(B)
1 0 Choice(A)
1 0 0 Leaf 3
1 0 1 Leaf 4
1 1 Leaf 5
*/
auto root = boost::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
CHECK(root);
auto choice0 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
CHECK(choice0);
EXPECT(choice0->branches()[0]->isLeaf());
auto choice00 = boost::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
CHECK(choice00);
EXPECT_LONGS_EQUAL(2, choice00->nrAssignments());
auto choice1 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]);
CHECK(choice1);
auto choice10 = boost::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]);
CHECK(choice10);
auto choice11 = boost::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]);
CHECK(choice11);
EXPECT(choice11->isLeaf());
EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
}
/* ************************************************************************** */ /* ************************************************************************** */
// Test visit. // Test visit.
TEST(DecisionTree, visit) { TEST(DecisionTree, visit) {