Merge pull request #1154 from borglab/decisiontree/nrAssignments

Record number of assignments for each leaf
release/4.3a0
Varun Agrawal 2022-04-02 10:01:50 -04:00 committed by GitHub
commit 99c01c4dba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 10 deletions

View File

@ -59,15 +59,23 @@ namespace gtsam {
/** constant stored in this leaf */
Y constant_;
/** Constructor from constant */
Leaf(const Y& constant) :
constant_(constant) {}
/** The number of assignments contained within this leaf
* Particularly useful when leaves have been pruned.
*/
size_t nrAssignments_;
/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
/** return the constant */
const Y& constant() const {
return constant_;
}
/// Return the number of assignments contained within this leaf.
size_t nrAssignments() const { return nrAssignments_; }
/// Leaf-Leaf equality
bool sameLeaf(const Leaf& q) const override {
return constant_ == q.constant_;
@ -108,14 +116,14 @@ namespace gtsam {
/** apply unary operator */
NodePtr apply(const Unary& op) const override {
NodePtr f(new Leaf(op(constant_)));
NodePtr f(new Leaf(op(constant_), nrAssignments_));
return f;
}
/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
NodePtr f(new Leaf(op(choices, constant_)));
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
return f;
}
@ -130,7 +138,8 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
// fL op gL
NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
return h;
}
@ -141,7 +150,7 @@ namespace gtsam {
/** choose a branch, create new memory ! */
NodePtr choose(const L& label, size_t index) const override {
return NodePtr(new Leaf(constant()));
return NodePtr(new Leaf(constant(), nrAssignments()));
}
bool isLeaf() const override { return true; }
@ -178,9 +187,16 @@ namespace gtsam {
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];
assert(f0->isLeaf());
size_t nrAssignments = 0;
for(auto branch: f->branches()) {
assert(branch->isLeaf());
nrAssignments +=
boost::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
}
NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant(),
nrAssignments));
return newLeaf;
} else
#endif
@ -640,7 +656,7 @@ namespace gtsam {
// If leaf, apply unary conversion "op" and create a unique leaf.
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
}
// Check if Choice

View File

@ -323,6 +323,49 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int);
}
/* ************************************************************************** */
// Test nrAssignments.
TEST(DecisionTree, NrAssignments) {
pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
EXPECT(tree.root_->isLeaf());
auto leaf = boost::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
/* The tree is
Choice(C)
0 Choice(B)
0 0 Leaf 1
0 1 Choice(A)
0 1 0 Leaf 1
0 1 1 Leaf 2
1 Choice(B)
1 0 Choice(A)
1 0 0 Leaf 3
1 0 1 Leaf 4
1 1 Leaf 5
*/
auto root = boost::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
CHECK(root);
auto choice0 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
CHECK(choice0);
EXPECT(choice0->branches()[0]->isLeaf());
auto choice00 = boost::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
CHECK(choice00);
EXPECT_LONGS_EQUAL(2, choice00->nrAssignments());
auto choice1 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]);
CHECK(choice1);
auto choice10 = boost::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]);
CHECK(choice10);
auto choice11 = boost::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]);
CHECK(choice11);
EXPECT(choice11->isLeaf());
EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
}
/* ************************************************************************** */
// Test visit.
TEST(DecisionTree, visit) {