Merge pull request #1154 from borglab/decisiontree/nrAssignments
Record number of assignments for each leafrelease/4.3a0
commit
99c01c4dba
|
@ -59,15 +59,23 @@ namespace gtsam {
|
|||
/** constant stored in this leaf */
|
||||
Y constant_;
|
||||
|
||||
/** Constructor from constant */
|
||||
Leaf(const Y& constant) :
|
||||
constant_(constant) {}
|
||||
/** The number of assignments contained within this leaf
|
||||
* Particularly useful when leaves have been pruned.
|
||||
*/
|
||||
size_t nrAssignments_;
|
||||
|
||||
/// Constructor from constant
|
||||
Leaf(const Y& constant, size_t nrAssignments = 1)
|
||||
: constant_(constant), nrAssignments_(nrAssignments) {}
|
||||
|
||||
/** return the constant */
|
||||
const Y& constant() const {
|
||||
return constant_;
|
||||
}
|
||||
|
||||
/// Return the number of assignments contained within this leaf.
|
||||
size_t nrAssignments() const { return nrAssignments_; }
|
||||
|
||||
/// Leaf-Leaf equality
|
||||
bool sameLeaf(const Leaf& q) const override {
|
||||
return constant_ == q.constant_;
|
||||
|
@ -108,14 +116,14 @@ namespace gtsam {
|
|||
|
||||
/** apply unary operator */
|
||||
NodePtr apply(const Unary& op) const override {
|
||||
NodePtr f(new Leaf(op(constant_)));
|
||||
NodePtr f(new Leaf(op(constant_), nrAssignments_));
|
||||
return f;
|
||||
}
|
||||
|
||||
/// Apply unary operator with assignment
|
||||
NodePtr apply(const UnaryAssignment& op,
|
||||
const Assignment<L>& choices) const override {
|
||||
NodePtr f(new Leaf(op(choices, constant_)));
|
||||
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
|
||||
return f;
|
||||
}
|
||||
|
||||
|
@ -130,7 +138,8 @@ namespace gtsam {
|
|||
|
||||
// Applying binary operator to two leaves results in a leaf
|
||||
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
||||
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
|
||||
// fL op gL
|
||||
NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
|
||||
return h;
|
||||
}
|
||||
|
||||
|
@ -141,7 +150,7 @@ namespace gtsam {
|
|||
|
||||
/** choose a branch, create new memory ! */
|
||||
NodePtr choose(const L& label, size_t index) const override {
|
||||
return NodePtr(new Leaf(constant()));
|
||||
return NodePtr(new Leaf(constant(), nrAssignments()));
|
||||
}
|
||||
|
||||
bool isLeaf() const override { return true; }
|
||||
|
@ -178,9 +187,16 @@ namespace gtsam {
|
|||
if (f->allSame_) {
|
||||
assert(f->branches().size() > 0);
|
||||
NodePtr f0 = f->branches_[0];
|
||||
assert(f0->isLeaf());
|
||||
|
||||
size_t nrAssignments = 0;
|
||||
for(auto branch: f->branches()) {
|
||||
assert(branch->isLeaf());
|
||||
nrAssignments +=
|
||||
boost::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
|
||||
}
|
||||
NodePtr newLeaf(
|
||||
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant(),
|
||||
nrAssignments));
|
||||
return newLeaf;
|
||||
} else
|
||||
#endif
|
||||
|
@ -640,7 +656,7 @@ namespace gtsam {
|
|||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) {
|
||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||
return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
|
||||
}
|
||||
|
||||
// Check if Choice
|
||||
|
|
|
@ -323,6 +323,49 @@ TEST(DecisionTree, Containers) {
|
|||
StringContainerTree converted(stringIntTree, container_of_int);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test nrAssignments.
|
||||
TEST(DecisionTree, NrAssignments) {
|
||||
pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
|
||||
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
|
||||
EXPECT(tree.root_->isLeaf());
|
||||
auto leaf = boost::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
|
||||
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
|
||||
|
||||
DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
|
||||
/* The tree is
|
||||
Choice(C)
|
||||
0 Choice(B)
|
||||
0 0 Leaf 1
|
||||
0 1 Choice(A)
|
||||
0 1 0 Leaf 1
|
||||
0 1 1 Leaf 2
|
||||
1 Choice(B)
|
||||
1 0 Choice(A)
|
||||
1 0 0 Leaf 3
|
||||
1 0 1 Leaf 4
|
||||
1 1 Leaf 5
|
||||
*/
|
||||
|
||||
auto root = boost::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
|
||||
CHECK(root);
|
||||
auto choice0 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
|
||||
CHECK(choice0);
|
||||
EXPECT(choice0->branches()[0]->isLeaf());
|
||||
auto choice00 = boost::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
|
||||
CHECK(choice00);
|
||||
EXPECT_LONGS_EQUAL(2, choice00->nrAssignments());
|
||||
|
||||
auto choice1 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]);
|
||||
CHECK(choice1);
|
||||
auto choice10 = boost::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]);
|
||||
CHECK(choice10);
|
||||
auto choice11 = boost::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]);
|
||||
CHECK(choice11);
|
||||
EXPECT(choice11->isLeaf());
|
||||
EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test visit.
|
||||
TEST(DecisionTree, visit) {
|
||||
|
|
Loading…
Reference in New Issue