parent
							
								
									2940e69a73
								
							
						
					
					
						commit
						2f4133fd49
					
				|  | @ -93,7 +93,8 @@ namespace gtsam { | |||
|     /// print
 | ||||
|     void print(const std::string& s, const LabelFormatter& labelFormatter, | ||||
|                const ValueFormatter& valueFormatter) const override { | ||||
|       std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; | ||||
|       std::cout << s << " Leaf [" << nrAssignments() << "]" | ||||
|                 << valueFormatter(constant_) << std::endl; | ||||
|     } | ||||
| 
 | ||||
|     /** Write graphviz format to stream `os`. */ | ||||
|  | @ -827,6 +828,16 @@ namespace gtsam { | |||
|     return total; | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|   template <typename L, typename Y> | ||||
|   size_t DecisionTree<L, Y>::nrAssignments() const { | ||||
|     size_t n = 0; | ||||
|     this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) { | ||||
|       n += leaf.nrAssignments(); | ||||
|     }); | ||||
|     return n; | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|   // fold is just done with a visit
 | ||||
|   template <typename L, typename Y> | ||||
|  |  | |||
|  | @ -299,6 +299,42 @@ namespace gtsam { | |||
|     /// Return the number of leaves in the tree.
 | ||||
|     size_t nrLeaves() const; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief This is a convenience function which returns the total number of | ||||
|      * leaf assignments in the decision tree. | ||||
|      * This function is not used for anymajor operations within the discrete | ||||
|      * factor graph framework. | ||||
|      * | ||||
|      * Leaf assignments represent the cardinality of each leaf node, e.g. in a | ||||
|      * binary tree each leaf has 2 assignments. This includes counts removed | ||||
|      * from implicit pruning hence, it will always be >= nrLeaves(). | ||||
|      * | ||||
|      * E.g. we have a decision tree as below, where each node has 2 branches: | ||||
|      * | ||||
|      * Choice(m1) | ||||
|      * 0 Choice(m0) | ||||
|      * 0 0 Leaf 0.0 | ||||
|      * 0 1 Leaf 0.0 | ||||
|      * 1 Choice(m0) | ||||
|      * 1 0 Leaf 1.0 | ||||
|      * 1 1 Leaf 2.0 | ||||
|      * | ||||
|      * In the unpruned form, the tree will have 4 assignments, 2 for each key, | ||||
|      * and 4 leaves. | ||||
|      * | ||||
|      * In the pruned form, the number of assignments is still 4 but the number | ||||
|      * of leaves is now 3, as below: | ||||
|      * | ||||
|      * Choice(m1) | ||||
|      * 0 Leaf 0.0 | ||||
|      * 1 Choice(m0) | ||||
|      * 1 0 Leaf 1.0 | ||||
|      * 1 1 Leaf 2.0 | ||||
|      * | ||||
|      * @return size_t | ||||
|      */ | ||||
|     size_t nrAssignments() const; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Fold a binary function over the tree, returning accumulator. | ||||
|      * | ||||
|  |  | |||
|  | @ -531,6 +531,38 @@ TEST(DecisionTree, ApplyWithAssignment) { | |||
|   EXPECT_LONGS_EQUAL(5, count); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test number of assignments.
 | ||||
| TEST(DecisionTree, NrAssignments2) { | ||||
|   using gtsam::symbol_shorthand::M; | ||||
| 
 | ||||
|   std::vector<double> probs = {0, 0, 1, 2}; | ||||
| 
 | ||||
|   /* Create the decision tree
 | ||||
|     Choice(m1) | ||||
|     0 Leaf 0.000000 | ||||
|     1 Choice(m0) | ||||
|     1 0 Leaf 1.000000 | ||||
|     1 1 Leaf 2.000000 | ||||
|   */ | ||||
|   DiscreteKeys keys{{M(1), 2}, {M(0), 2}}; | ||||
|   DecisionTree<Key, double> dt1(keys, probs); | ||||
|   EXPECT_LONGS_EQUAL(4, dt1.nrAssignments()); | ||||
| 
 | ||||
|   /* Create the DecisionTree
 | ||||
|     Choice(m1) | ||||
|     0 Choice(m0) | ||||
|     0 0 Leaf 0.000000 | ||||
|     0 1 Leaf 1.000000 | ||||
|     1 Choice(m0) | ||||
|     1 0 Leaf 0.000000 | ||||
|     1 1 Leaf 2.000000 | ||||
|   */ | ||||
|   DiscreteKeys keys2{{M(0), 2}, {M(1), 2}}; | ||||
|   DecisionTree<Key, double> dt2(keys2, probs); | ||||
|   EXPECT_LONGS_EQUAL(4, dt2.nrAssignments()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue