use a fixed size min-heap to find the pruning threshold

release/4.3a0
Varun Agrawal 2024-11-06 15:06:46 -05:00
parent 8b968c1401
commit d21f191219
1 changed file with 48 additions and 9 deletions

View File

@@ -353,18 +353,57 @@ namespace gtsam {
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments;
// Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities = this->probabilities();
// Set of all keys
std::set<Key> allKeys(keys().begin(), keys().end());
std::vector<double> min_heap;
// The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) {
return *this;
auto op = [&](const Assignment<Key>& a, double p) {
// Get all the keys in the current assignment
std::set<Key> assignment_keys;
for (auto&& [k, _] : a) {
assignment_keys.insert(k);
}
std::sort(probabilities.begin(), probabilities.end(),
// Find the keys missing in the assignment
std::vector<Key> diff;
std::set_difference(allKeys.begin(), allKeys.end(),
assignment_keys.begin(), assignment_keys.end(),
std::back_inserter(diff));
// Compute the total number of assignments in the (pruned) subtree
size_t nrAssignments = 1;
for (auto&& k : diff) {
nrAssignments *= cardinalities_.at(k);
}
if (min_heap.empty()) {
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
min_heap.push_back(p);
}
std::make_heap(min_heap.begin(), min_heap.end(),
std::greater<double>{});
double threshold = probabilities[N - 1];
} else {
// If p is larger than the smallest element,
// then we insert into the min heap.
if (p > min_heap.at(0)) {
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
if (min_heap.size() == N) {
std::pop_heap(min_heap.begin(), min_heap.end(),
std::greater<double>{});
min_heap.pop_back();
}
min_heap.push_back(p);
std::make_heap(min_heap.begin(), min_heap.end(),
std::greater<double>{});
}
}
}
return p;
};
this->visitWith(op);
double threshold = min_heap.at(0);
// Now threshold the decision tree
size_t total = 0;