New prune method for DecisionTreeFactor
parent
a9a4075ff6
commit
4c966b9753
|
@ -286,5 +286,43 @@ namespace gtsam {
|
|||
AlgebraicDecisionTree<Key>(keys, table),
|
||||
cardinalities_(keys.cardinalities()) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
/**
 * Keep at most the `maxNrLeaves` largest leaf values of this factor and set
 * every other leaf value to 0.
 *
 * @param maxNrLeaves The maximum number of leaves allowed to keep their value.
 *   A value of 0 zeroes every leaf.
 * @return A copy of this factor if it holds no more than `maxNrLeaves` leaf
 *   values; otherwise a new factor over the same discrete keys in which the
 *   pruned leaves are 0.
 */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const {
  const size_t N = maxNrLeaves;

  // Collect every leaf value reported by visit().
  std::vector<double> probabilities;
  this->visit([&](const double& prob) { probabilities.emplace_back(prob); });

  // The factor is already within the leaf budget: nothing to prune.
  if (probabilities.size() <= N) {
    return *this;
  }

  // The cut-off is the N-th largest leaf value. For N == 0 no such element
  // exists (indexing N - 1 would wrap around and read out of bounds); the
  // cut-off is irrelevant in that case because the budget check in
  // thresholdFunc already zeroes every leaf.
  double threshold = 0.0;
  if (N > 0) {
    const auto nth =
        probabilities.begin() +
        static_cast<std::vector<double>::difference_type>(N - 1);
    // Only one order statistic is needed, so select it instead of sorting
    // the whole vector.
    std::nth_element(probabilities.begin(), nth, probabilities.end(),
                     std::greater<double>{});
    threshold = *nth;
  }

  // Zero every value below the cut-off. `total` counts the values kept so
  // far, so values tied with the cut-off cannot push the count above N.
  size_t total = 0;
  auto thresholdFunc = [threshold, &total, N](const double& value) {
    if (value < threshold || total >= N) {
      return 0.0;
    }
    total += 1;
    return value;
  };
  DecisionTree<Key, double> thresholded(*this, thresholdFunc);

  // Wrap the thresholded tree in a factor over the same discrete keys.
  return DecisionTreeFactor(this->discreteKeys(), thresholded);
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -170,6 +170,18 @@ namespace gtsam {
|
|||
/// Return all the discrete keys associated with this factor.
/// Const member; the keys are returned by value as a DiscreteKeys object.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/**
 * @brief Prune the decision tree of discrete variables.
 *
 * Pruning will set the leaves to be "pruned" to 0 indicating a 0
 * probability.
 * A leaf is pruned if it is not in the top `maxNrLeaves` values.
 * If the factor has no more than `maxNrLeaves` leaves, a copy of the
 * factor is returned unchanged.
 *
 * @param maxNrLeaves The maximum number of leaves to keep.
 * @return DecisionTreeFactor The pruned factor, returned by value; this
 * factor itself is not modified.
 */
|
||||
DecisionTreeFactor prune(size_t maxNrLeaves) const;
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
@ -106,6 +106,27 @@ TEST(DecisionTreeFactor, enumerate) {
|
|||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check pruning of the decision tree works as expected.
|
||||
TEST(DecisionTreeFactor, Prune) {
|
||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
|
||||
|
||||
// Only keep the leaves with the top 5 values.
|
||||
size_t maxNrLeaves = 5;
|
||||
auto pruned5 = f.prune(maxNrLeaves);
|
||||
|
||||
// Pruned leaves should be 0
|
||||
DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
|
||||
EXPECT(assert_equal(expected, pruned5));
|
||||
|
||||
// Check for more extreme pruning where we only keep the top 2 leaves
|
||||
maxNrLeaves = 2;
|
||||
auto pruned2 = f.prune(maxNrLeaves);
|
||||
DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
|
||||
EXPECT(assert_equal(expected2, pruned2));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DecisionTreeFactor, DotWithNames) {
|
||||
DiscreteKey A(12, 3), B(5, 2);
|
||||
|
|
Loading…
Reference in New Issue