From dac84e99321f992620a1004e9f873b4c14c85024 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 31 Mar 2022 10:04:00 -0400 Subject: [PATCH] update prune to new max number of assignments scheme --- gtsam/discrete/DecisionTreeFactor.cpp | 10 +++++++--- gtsam/discrete/DecisionTreeFactor.h | 7 ++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index acd4d4af2..4e16fc689 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -287,12 +287,16 @@ namespace gtsam { cardinalities_(keys.cardinalities()) {} /* ************************************************************************ */ - DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const { - const size_t N = maxNrLeaves; + DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { + const size_t N = maxNrAssignments; // Get the probabilities in the decision tree so we can threshold. std::vector probabilities; - this->visit([&](const double& prob) { probabilities.emplace_back(prob); }); + this->visitLeaf([&](const Leaf& leaf) { + size_t nrAssignments = leaf.nrAssignments(); + double prob = leaf.constant(); + probabilities.insert(probabilities.end(), nrAssignments, prob); + }); // The number of probabilities can be lower than max_leaves if (probabilities.size() <= N) { diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 1f3d69292..286571ffc 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -175,12 +175,13 @@ namespace gtsam { * * Pruning will set the leaves to be "pruned" to 0 indicating a 0 * probability. - * A leaf is pruned if it is not in the top `maxNrLeaves` values. + * An assignment is pruned if it is not in the top `maxNrAssignments` + * values. * - * @param maxNrLeaves The maximum number of leaves to keep. + * @param maxNrAssignments The maximum number of assignments to keep. * @return DecisionTreeFactor */ - DecisionTreeFactor prune(size_t maxNrLeaves) const; + DecisionTreeFactor prune(size_t maxNrAssignments) const; /// @} /// @name Wrapper support