New prune method for DecisionTreeFactor
							parent
							
								
									a9a4075ff6
								
							
						
					
					
						commit
						4c966b9753
					
				|  | @ -286,5 +286,43 @@ namespace gtsam { | |||
|         AlgebraicDecisionTree<Key>(keys, table), | ||||
|         cardinalities_(keys.cardinalities()) {} | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
|   DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const { | ||||
|     const size_t N = maxNrLeaves; | ||||
| 
 | ||||
|     // Let's assume that the structure of the last discrete density will be the
 | ||||
|     // same as the last continuous
 | ||||
|     std::vector<double> probabilities; | ||||
|     // number of choices
 | ||||
|     this->visit([&](const double& prob) { probabilities.emplace_back(prob); }); | ||||
| 
 | ||||
|     // The number of probabilities can be lower than max_leaves
 | ||||
|     if (probabilities.size() <= N) { | ||||
|       return *this; | ||||
|     } | ||||
| 
 | ||||
|     std::sort(probabilities.begin(), probabilities.end(), | ||||
|               std::greater<double>{}); | ||||
| 
 | ||||
|     double threshold = probabilities[N - 1]; | ||||
| 
 | ||||
|     // Now threshold the decision tree
 | ||||
|     size_t total = 0; | ||||
|     auto thresholdFunc = [threshold, &total, N](const double& value) { | ||||
|       if (value < threshold || total >= N) { | ||||
|         return 0.0; | ||||
|       } else { | ||||
|         total += 1; | ||||
|         return value; | ||||
|       } | ||||
|     }; | ||||
|     DecisionTree<Key, double> thresholded(*this, thresholdFunc); | ||||
| 
 | ||||
|     // Create pruned decision tree factor
 | ||||
|     DecisionTreeFactor prunedDiscreteFactor(this->discreteKeys(), thresholded); | ||||
| 
 | ||||
|     return prunedDiscreteFactor; | ||||
|   } | ||||
| 
 | ||||
|   /* ************************************************************************ */ | ||||
| }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -170,6 +170,18 @@ namespace gtsam { | |||
|     /// Return all the discrete keys associated with this factor.
 | ||||
|     DiscreteKeys discreteKeys() const; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Prune the decision tree of discrete variables. | ||||
|      * | ||||
|      * Pruning will set the leaves to be "pruned" to 0 indicating a 0 | ||||
|      * probability. | ||||
|      * A leaf is pruned if it is not in the top `maxNrLeaves` values. | ||||
|      * This factor itself is not modified; the result is returned by value. | ||||
|      * | ||||
|      * @param maxNrLeaves The maximum number of leaves to keep. | ||||
|      * @return DecisionTreeFactor A new factor with the pruned leaves set to 0. | ||||
|      */ | ||||
|     DecisionTreeFactor prune(size_t maxNrLeaves) const; | ||||
| 
 | ||||
|     /// @}
 | ||||
|     /// @name Wrapper support
 | ||||
|     /// @{
 | ||||
|  |  | |||
|  | @ -106,6 +106,27 @@ TEST(DecisionTreeFactor, enumerate) { | |||
|   EXPECT(actual == expected); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check pruning of the decision tree works as expected.
 | ||||
| TEST(DecisionTreeFactor, Prune) { | ||||
|   DiscreteKey A(1, 2), B(2, 2), C(3, 2); | ||||
|   DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8"); | ||||
| 
 | ||||
|   // Only keep the leaves with the top 5 values.
 | ||||
|   size_t maxNrLeaves = 5; | ||||
|   auto pruned5 = f.prune(maxNrLeaves); | ||||
| 
 | ||||
|   // Pruned leaves should be 0
 | ||||
|   DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8"); | ||||
|   EXPECT(assert_equal(expected, pruned5)); | ||||
| 
 | ||||
|   // Check for more extreme pruning where we only keep the top 2 leaves
 | ||||
|   maxNrLeaves = 2; | ||||
|   auto pruned2 = f.prune(maxNrLeaves); | ||||
|   DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); | ||||
|   EXPECT(assert_equal(expected2, pruned2)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DecisionTreeFactor, DotWithNames) { | ||||
|   DiscreteKey A(12, 3), B(5, 2); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue