New prune method for DecisionTreeFactor

release/4.3a0
Varun Agrawal 2022-03-27 13:25:35 -04:00
parent a9a4075ff6
commit 4c966b9753
3 changed files with 71 additions and 0 deletions

View File

@ -286,5 +286,43 @@ namespace gtsam {
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const {
const size_t N = maxNrLeaves;
// Let's assume that the structure of the last discrete density will be the
// same as the last continuous
std::vector<double> probabilities;
// number of choices
this->visit([&](const double& prob) { probabilities.emplace_back(prob); });
// The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) {
return *this;
}
std::sort(probabilities.begin(), probabilities.end(),
std::greater<double>{});
double threshold = probabilities[N - 1];
// Now threshold the decision tree
size_t total = 0;
auto thresholdFunc = [threshold, &total, N](const double& value) {
if (value < threshold || total >= N) {
return 0.0;
} else {
total += 1;
return value;
}
};
DecisionTree<Key, double> thresholded(*this, thresholdFunc);
// Create pruned decision tree factor
DecisionTreeFactor prunedDiscreteFactor(this->discreteKeys(), thresholded);
return prunedDiscreteFactor;
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -170,6 +170,18 @@ namespace gtsam {
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/**
* @brief Prune the decision tree of discrete variables.
*
* Pruning will set the leaves to be "pruned" to 0 indicating a 0
* probability.
* A leaf is pruned if it is not in the top `maxNrLeaves` values.
*
* @param maxNrLeaves The maximum number of leaves to keep.
* @return DecisionTreeFactor::shared_ptr
*/
DecisionTreeFactor prune(size_t maxNrLeaves) const;
/// @}
/// @name Wrapper support
/// @{

View File

@ -106,6 +106,27 @@ TEST(DecisionTreeFactor, enumerate) {
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Check pruning of the decision tree works as expected.
TEST(DecisionTreeFactor, Prune) {
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
// Only keep the leaves with the top 5 values.
size_t maxNrLeaves = 5;
auto pruned5 = f.prune(maxNrLeaves);
// Pruned leaves should be 0
DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
EXPECT(assert_equal(expected, pruned5));
// Check for more extreme pruning where we only keep the top 2 leaves
maxNrLeaves = 2;
auto pruned2 = f.prune(maxNrLeaves);
DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
EXPECT(assert_equal(expected2, pruned2));
}
/* ************************************************************************* */
TEST(DecisionTreeFactor, DotWithNames) {
DiscreteKey A(12, 3), B(5, 2);