From d21f191219a945c91546e8d77bb0badc3f877446 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 6 Nov 2024 15:06:46 -0500 Subject: [PATCH] use a fixed size min-heap to find the pruning threshold --- gtsam/discrete/DecisionTreeFactor.cpp | 57 ++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 4e7a7342e..9b541bbf0 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -353,18 +353,57 @@ namespace gtsam { DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { const size_t N = maxNrAssignments; - // Get the probabilities in the decision tree so we can threshold. - std::vector probabilities = this->probabilities(); + // Set of all keys + std::set allKeys(keys().begin(), keys().end()); + std::vector min_heap; - // The number of probabilities can be lower than max_leaves - if (probabilities.size() <= N) { - return *this; - } + auto op = [&](const Assignment& a, double p) { + // Get all the keys in the current assignment + std::set assignment_keys; + for (auto&& [k, _] : a) { + assignment_keys.insert(k); + } - std::sort(probabilities.begin(), probabilities.end(), - std::greater{}); + // Find the keys missing in the assignment + std::vector diff; + std::set_difference(allKeys.begin(), allKeys.end(), + assignment_keys.begin(), assignment_keys.end(), + std::back_inserter(diff)); - double threshold = probabilities[N - 1]; + // Compute the total number of assignments in the (pruned) subtree + size_t nrAssignments = 1; + for (auto&& k : diff) { + nrAssignments *= cardinalities_.at(k); + } + + if (min_heap.empty()) { + for (size_t i = 0; i < std::min(nrAssignments, N); ++i) { + min_heap.push_back(p); + } + std::make_heap(min_heap.begin(), min_heap.end(), + std::greater{}); + + } else { + // If p is larger than the smallest element, + // then we insert into the max heap. + if (p > min_heap.at(0)) { + for (size_t i = 0; i < std::min(nrAssignments, N); ++i) { + if (min_heap.size() == N) { + std::pop_heap(min_heap.begin(), min_heap.end(), + std::greater{}); + min_heap.pop_back(); + } + min_heap.push_back(p); + std::make_heap(min_heap.begin(), min_heap.end(), + std::greater{}); + } + } + } + return p; + }; + this->visitWith(op); + + double threshold = min_heap.at(0); // Now threshold the decision tree size_t total = 0;