Merge pull request #1154 from borglab/decisiontree/nrAssignments
Record number of assignments for each leafrelease/4.3a0
						commit
						99c01c4dba
					
				|  | @ -59,15 +59,23 @@ namespace gtsam { | |||
|     /** constant stored in this leaf */ | ||||
|     Y constant_; | ||||
| 
 | ||||
|     /** Constructor from constant */ | ||||
|     Leaf(const Y& constant) : | ||||
|       constant_(constant) {} | ||||
|     /** The number of assignments contained within this leaf
 | ||||
|      * Particularly useful when leaves have been pruned. | ||||
|      */ | ||||
|     size_t nrAssignments_; | ||||
| 
 | ||||
|     /// Constructor from constant
 | ||||
|     Leaf(const Y& constant, size_t nrAssignments = 1) | ||||
|         : constant_(constant), nrAssignments_(nrAssignments) {} | ||||
| 
 | ||||
|     /** return the constant */ | ||||
|     const Y& constant() const { | ||||
|       return constant_; | ||||
|     } | ||||
| 
 | ||||
|     /// Return the number of assignments contained within this leaf.
 | ||||
|     size_t nrAssignments() const { return nrAssignments_; } | ||||
| 
 | ||||
|     /// Leaf-Leaf equality
 | ||||
|     bool sameLeaf(const Leaf& q) const override { | ||||
|       return constant_ == q.constant_; | ||||
|  | @ -108,14 +116,14 @@ namespace gtsam { | |||
| 
 | ||||
|     /** apply unary operator */ | ||||
|     NodePtr apply(const Unary& op) const override { | ||||
|       NodePtr f(new Leaf(op(constant_))); | ||||
|       NodePtr f(new Leaf(op(constant_), nrAssignments_)); | ||||
|       return f; | ||||
|     } | ||||
| 
 | ||||
|     /// Apply unary operator with assignment
 | ||||
|     NodePtr apply(const UnaryAssignment& op, | ||||
|                   const Assignment<L>& choices) const override { | ||||
|       NodePtr f(new Leaf(op(choices, constant_))); | ||||
|       NodePtr f(new Leaf(op(choices, constant_), nrAssignments_)); | ||||
|       return f; | ||||
|     } | ||||
| 
 | ||||
|  | @ -130,7 +138,8 @@ namespace gtsam { | |||
| 
 | ||||
|     // Applying binary operator to two leaves results in a leaf
 | ||||
|     NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { | ||||
|       NodePtr h(new Leaf(op(fL.constant_, constant_)));  // fL op gL
 | ||||
|       // fL op gL
 | ||||
|       NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_)); | ||||
|       return h; | ||||
|     } | ||||
| 
 | ||||
|  | @ -141,7 +150,7 @@ namespace gtsam { | |||
| 
 | ||||
|     /** choose a branch, create new memory ! */ | ||||
|     NodePtr choose(const L& label, size_t index) const override { | ||||
|       return NodePtr(new Leaf(constant())); | ||||
|       return NodePtr(new Leaf(constant(), nrAssignments())); | ||||
|     } | ||||
| 
 | ||||
|     bool isLeaf() const override { return true; } | ||||
|  | @ -178,9 +187,16 @@ namespace gtsam { | |||
|       if (f->allSame_) { | ||||
|         assert(f->branches().size() > 0); | ||||
|         NodePtr f0 = f->branches_[0]; | ||||
|         assert(f0->isLeaf()); | ||||
| 
 | ||||
|         size_t nrAssignments = 0; | ||||
|         for(auto branch: f->branches()) { | ||||
|           assert(branch->isLeaf()); | ||||
|           nrAssignments += | ||||
|               boost::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments(); | ||||
|         } | ||||
|         NodePtr newLeaf( | ||||
|             new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant())); | ||||
|             new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant(), | ||||
|                      nrAssignments)); | ||||
|         return newLeaf; | ||||
|       } else | ||||
| #endif | ||||
|  | @ -640,7 +656,7 @@ namespace gtsam { | |||
|     // If leaf, apply unary conversion "op" and create a unique leaf.
 | ||||
|     using MXLeaf = typename DecisionTree<M, X>::Leaf; | ||||
|     if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) { | ||||
|       return NodePtr(new Leaf(Y_of_X(leaf->constant()))); | ||||
|       return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments())); | ||||
|     } | ||||
| 
 | ||||
|     // Check if Choice
 | ||||
|  |  | |||
|  | @ -323,6 +323,49 @@ TEST(DecisionTree, Containers) { | |||
|   StringContainerTree converted(stringIntTree, container_of_int); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test nrAssignments.
 | ||||
| TEST(DecisionTree, NrAssignments) { | ||||
|   pair<string, size_t> A("A", 2), B("B", 2), C("C", 2); | ||||
|   DT tree({A, B, C}, "1 1 1 1 1 1 1 1"); | ||||
|   EXPECT(tree.root_->isLeaf()); | ||||
|   auto leaf = boost::dynamic_pointer_cast<const DT::Leaf>(tree.root_); | ||||
|   EXPECT_LONGS_EQUAL(8, leaf->nrAssignments()); | ||||
| 
 | ||||
|   DT tree2({C, B, A}, "1 1 1 2 3 4 5 5"); | ||||
|   /* The tree is
 | ||||
|     Choice(C)  | ||||
|     0 Choice(B)  | ||||
|     0 0 Leaf 1 | ||||
|     0 1 Choice(A)  | ||||
|     0 1 0 Leaf 1 | ||||
|     0 1 1 Leaf 2 | ||||
|     1 Choice(B)  | ||||
|     1 0 Choice(A)  | ||||
|     1 0 0 Leaf 3 | ||||
|     1 0 1 Leaf 4 | ||||
|     1 1 Leaf 5 | ||||
|   */ | ||||
| 
 | ||||
|   auto root = boost::dynamic_pointer_cast<const DT::Choice>(tree2.root_); | ||||
|   CHECK(root); | ||||
|   auto choice0 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]); | ||||
|   CHECK(choice0); | ||||
|   EXPECT(choice0->branches()[0]->isLeaf()); | ||||
|   auto choice00 = boost::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]); | ||||
|   CHECK(choice00); | ||||
|   EXPECT_LONGS_EQUAL(2, choice00->nrAssignments()); | ||||
| 
 | ||||
|   auto choice1 = boost::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]); | ||||
|   CHECK(choice1); | ||||
|   auto choice10 = boost::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]); | ||||
|   CHECK(choice10); | ||||
|   auto choice11 = boost::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]); | ||||
|   CHECK(choice11); | ||||
|   EXPECT(choice11->isLeaf()); | ||||
|   EXPECT_LONGS_EQUAL(2, choice11->nrAssignments()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test visit.
 | ||||
| TEST(DecisionTree, visit) { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue