diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 9b541bbf0..c8efc5fa5 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -349,13 +349,67 @@ namespace gtsam { : DiscreteFactor(keys.indices(), keys.cardinalities()), AlgebraicDecisionTree(keys, table) {} + /** + * @brief Min-Heap class to help with pruning. + * The `top` element is always the smallest value. + */ + class MinHeap { + std::vector v_; + + public: + /// Default constructor + MinHeap() {} + + /// Push value onto the heap + void push(double x) { + v_.push_back(x); + std::make_heap(v_.begin(), v_.end(), std::greater{}); + } + + /// Push value `x`, `n` number of times. + void push(double x, size_t n) { + v_.insert(v_.end(), n, x); + std::make_heap(v_.begin(), v_.end(), std::greater{}); + } + + /// Pop the top value of the heap. + double pop() { + std::pop_heap(v_.begin(), v_.end(), std::greater{}); + double x = v_.back(); + v_.pop_back(); + return x; + } + + /// Return the top value of the heap without popping it. + double top() { return v_.at(0); } + + /** + * @brief Print the heap as a sequence. + * + * @param s A string to prologue the output. + */ + void print(const std::string& s = "") { + std::cout << (s.empty() ? "" : s + " "); + for (size_t i = 0; i < v_.size() - 1; i++) { + std::cout << v_.at(i) << ","; + } + std::cout << v_.at(v_.size() - 1) << std::endl; + } + + /// Return true if heap is empty. + bool empty() const { return v_.empty(); } + + /// Return the size of the heap. + size_t size() const { return v_.size(); } + }; + /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { const size_t N = maxNrAssignments; // Set of all keys std::set allKeys(keys().begin(), keys().end()); - std::vector min_heap; + MinHeap min_heap; auto op = [&](const Assignment& a, double p) { // Get all the keys in the current assignment @@ -377,25 +431,17 @@ namespace gtsam { } if (min_heap.empty()) { - for (size_t i = 0; i < std::min(nrAssignments, N); ++i) { - min_heap.push_back(p); - } - std::make_heap(min_heap.begin(), min_heap.end(), - std::greater{}); + min_heap.push(p, std::min(nrAssignments, N)); } else { // If p is larger than the smallest element, // then we insert into the max heap. - if (p > min_heap.at(0)) { + if (p > min_heap.top()) { for (size_t i = 0; i < std::min(nrAssignments, N); ++i) { if (min_heap.size() == N) { - std::pop_heap(min_heap.begin(), min_heap.end(), - std::greater{}); - min_heap.pop_back(); + min_heap.pop(); } - min_heap.push_back(p); - std::make_heap(min_heap.begin(), min_heap.end(), - std::greater{}); + min_heap.push(p); } } } @@ -403,7 +449,7 @@ namespace gtsam { }; this->visitWith(op); - double threshold = min_heap.at(0); + double threshold = min_heap.top(); // Now threshold the decision tree size_t total = 0;