New prune method for DecisionTreeFactor
parent
a9a4075ff6
commit
4c966b9753
|
@ -286,5 +286,43 @@ namespace gtsam {
|
||||||
AlgebraicDecisionTree<Key>(keys, table),
|
AlgebraicDecisionTree<Key>(keys, table),
|
||||||
cardinalities_(keys.cardinalities()) {}
|
cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const {
|
||||||
|
const size_t N = maxNrLeaves;
|
||||||
|
|
||||||
|
// Let's assume that the structure of the last discrete density will be the
|
||||||
|
// same as the last continuous
|
||||||
|
std::vector<double> probabilities;
|
||||||
|
// number of choices
|
||||||
|
this->visit([&](const double& prob) { probabilities.emplace_back(prob); });
|
||||||
|
|
||||||
|
// The number of probabilities can be lower than max_leaves
|
||||||
|
if (probabilities.size() <= N) {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(probabilities.begin(), probabilities.end(),
|
||||||
|
std::greater<double>{});
|
||||||
|
|
||||||
|
double threshold = probabilities[N - 1];
|
||||||
|
|
||||||
|
// Now threshold the decision tree
|
||||||
|
size_t total = 0;
|
||||||
|
auto thresholdFunc = [threshold, &total, N](const double& value) {
|
||||||
|
if (value < threshold || total >= N) {
|
||||||
|
return 0.0;
|
||||||
|
} else {
|
||||||
|
total += 1;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
DecisionTree<Key, double> thresholded(*this, thresholdFunc);
|
||||||
|
|
||||||
|
// Create pruned decision tree factor
|
||||||
|
DecisionTreeFactor prunedDiscreteFactor(this->discreteKeys(), thresholded);
|
||||||
|
|
||||||
|
return prunedDiscreteFactor;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -170,6 +170,18 @@ namespace gtsam {
|
||||||
/// Return all the discrete keys associated with this factor.
|
/// Return all the discrete keys associated with this factor.
|
||||||
DiscreteKeys discreteKeys() const;
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Prune the decision tree of discrete variables.
|
||||||
|
*
|
||||||
|
* Pruning will set the leaves to be "pruned" to 0 indicating a 0
|
||||||
|
* probability.
|
||||||
|
* A leaf is pruned if it is not in the top `maxNrLeaves` values.
|
||||||
|
*
|
||||||
|
* @param maxNrLeaves The maximum number of leaves to keep.
|
||||||
|
* @return DecisionTreeFactor::shared_ptr
|
||||||
|
*/
|
||||||
|
DecisionTreeFactor prune(size_t maxNrLeaves) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -106,6 +106,27 @@ TEST(DecisionTreeFactor, enumerate) {
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check pruning of the decision tree works as expected.
|
||||||
|
TEST(DecisionTreeFactor, Prune) {
|
||||||
|
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||||
|
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
|
||||||
|
|
||||||
|
// Only keep the leaves with the top 5 values.
|
||||||
|
size_t maxNrLeaves = 5;
|
||||||
|
auto pruned5 = f.prune(maxNrLeaves);
|
||||||
|
|
||||||
|
// Pruned leaves should be 0
|
||||||
|
DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
|
||||||
|
EXPECT(assert_equal(expected, pruned5));
|
||||||
|
|
||||||
|
// Check for more extreme pruning where we only keep the top 2 leaves
|
||||||
|
maxNrLeaves = 2;
|
||||||
|
auto pruned2 = f.prune(maxNrLeaves);
|
||||||
|
DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
|
||||||
|
EXPECT(assert_equal(expected2, pruned2));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DecisionTreeFactor, DotWithNames) {
|
TEST(DecisionTreeFactor, DotWithNames) {
|
||||||
DiscreteKey A(12, 3), B(5, 2);
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
|
|
Loading…
Reference in New Issue