implement a min-heap to record the top N probabilities for pruning
parent
d21f191219
commit
9666725473
|
@ -349,13 +349,67 @@ namespace gtsam {
|
||||||
: DiscreteFactor(keys.indices(), keys.cardinalities()),
|
: DiscreteFactor(keys.indices(), keys.cardinalities()),
|
||||||
AlgebraicDecisionTree<Key>(keys, table) {}
|
AlgebraicDecisionTree<Key>(keys, table) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Min-Heap class to help with pruning.
|
||||||
|
* The `top` element is always the smallest value.
|
||||||
|
*/
|
||||||
|
class MinHeap {
|
||||||
|
std::vector<double> v_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/// Default constructor
|
||||||
|
MinHeap() {}
|
||||||
|
|
||||||
|
/// Push value onto the heap
|
||||||
|
void push(double x) {
|
||||||
|
v_.push_back(x);
|
||||||
|
std::make_heap(v_.begin(), v_.end(), std::greater<double>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Push value `x`, `n` number of times.
|
||||||
|
void push(double x, size_t n) {
|
||||||
|
v_.insert(v_.end(), n, x);
|
||||||
|
std::make_heap(v_.begin(), v_.end(), std::greater<double>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pop the top value of the heap.
|
||||||
|
double pop() {
|
||||||
|
std::pop_heap(v_.begin(), v_.end(), std::greater<double>{});
|
||||||
|
double x = v_.back();
|
||||||
|
v_.pop_back();
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the top value of the heap without popping it.
|
||||||
|
double top() { return v_.at(0); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Print the heap as a sequence.
|
||||||
|
*
|
||||||
|
* @param s A string to prologue the output.
|
||||||
|
*/
|
||||||
|
void print(const std::string& s = "") {
|
||||||
|
std::cout << (s.empty() ? "" : s + " ");
|
||||||
|
for (size_t i = 0; i < v_.size() - 1; i++) {
|
||||||
|
std::cout << v_.at(i) << ",";
|
||||||
|
}
|
||||||
|
std::cout << v_.at(v_.size() - 1) << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return true if heap is empty.
|
||||||
|
bool empty() const { return v_.empty(); }
|
||||||
|
|
||||||
|
/// Return the size of the heap.
|
||||||
|
size_t size() const { return v_.size(); }
|
||||||
|
};
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
||||||
const size_t N = maxNrAssignments;
|
const size_t N = maxNrAssignments;
|
||||||
|
|
||||||
// Set of all keys
|
// Set of all keys
|
||||||
std::set<Key> allKeys(keys().begin(), keys().end());
|
std::set<Key> allKeys(keys().begin(), keys().end());
|
||||||
std::vector<double> min_heap;
|
MinHeap min_heap;
|
||||||
|
|
||||||
auto op = [&](const Assignment<Key>& a, double p) {
|
auto op = [&](const Assignment<Key>& a, double p) {
|
||||||
// Get all the keys in the current assignment
|
// Get all the keys in the current assignment
|
||||||
|
@ -377,25 +431,17 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (min_heap.empty()) {
|
if (min_heap.empty()) {
|
||||||
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
|
min_heap.push(p, std::min(nrAssignments, N));
|
||||||
min_heap.push_back(p);
|
|
||||||
}
|
|
||||||
std::make_heap(min_heap.begin(), min_heap.end(),
|
|
||||||
std::greater<double>{});
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// If p is larger than the smallest element,
|
// If p is larger than the smallest element,
|
||||||
// then we insert into the max heap.
|
// then we insert into the max heap.
|
||||||
if (p > min_heap.at(0)) {
|
if (p > min_heap.top()) {
|
||||||
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
|
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
|
||||||
if (min_heap.size() == N) {
|
if (min_heap.size() == N) {
|
||||||
std::pop_heap(min_heap.begin(), min_heap.end(),
|
min_heap.pop();
|
||||||
std::greater<double>{});
|
|
||||||
min_heap.pop_back();
|
|
||||||
}
|
}
|
||||||
min_heap.push_back(p);
|
min_heap.push(p);
|
||||||
std::make_heap(min_heap.begin(), min_heap.end(),
|
|
||||||
std::greater<double>{});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -403,7 +449,7 @@ namespace gtsam {
|
||||||
};
|
};
|
||||||
this->visitWith(op);
|
this->visitWith(op);
|
||||||
|
|
||||||
double threshold = min_heap.at(0);
|
double threshold = min_heap.top();
|
||||||
|
|
||||||
// Now threshold the decision tree
|
// Now threshold the decision tree
|
||||||
size_t total = 0;
|
size_t total = 0;
|
||||||
|
|
Loading…
Reference in New Issue