update prune to new max number of assignments scheme

release/4.3a0
Varun Agrawal 2022-03-31 10:04:00 -04:00
parent 039ecfc3c3
commit dac84e9932
2 changed files with 11 additions and 6 deletions

View File

@ -287,12 +287,16 @@ namespace gtsam {
cardinalities_(keys.cardinalities()) {} cardinalities_(keys.cardinalities()) {}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const { DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrLeaves; const size_t N = maxNrAssignments;
// Get the probabilities in the decision tree so we can threshold. // Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities; std::vector<double> probabilities;
this->visit([&](const double& prob) { probabilities.emplace_back(prob); }); this->visitLeaf([&](const Leaf& leaf) {
size_t nrAssignments = leaf.nrAssignments();
double prob = leaf.constant();
probabilities.insert(probabilities.end(), nrAssignments, prob);
});
// The number of probabilities can be lower than max_leaves // The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) { if (probabilities.size() <= N) {

View File

@ -175,12 +175,13 @@ namespace gtsam {
* *
* Pruning will set the leaves to be "pruned" to 0 indicating a 0 * Pruning will set the leaves to be "pruned" to 0 indicating a 0
* probability. * probability.
* A leaf is pruned if it is not in the top `maxNrLeaves` values. * An assignment is pruned if it is not in the top `maxNrAssignments`
* values.
* *
* @param maxNrLeaves The maximum number of leaves to keep. * @param maxNrAssignments The maximum number of assignments to keep.
* @return DecisionTreeFactor * @return DecisionTreeFactor
*/ */
DecisionTreeFactor prune(size_t maxNrLeaves) const; DecisionTreeFactor prune(size_t maxNrAssignments) const;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support