use a fixed size min-heap to find the pruning threshold
parent
8b968c1401
commit
d21f191219
|
@ -353,18 +353,57 @@ namespace gtsam {
|
||||||
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
||||||
const size_t N = maxNrAssignments;
|
const size_t N = maxNrAssignments;
|
||||||
|
|
||||||
// Get the probabilities in the decision tree so we can threshold.
|
// Set of all keys
|
||||||
std::vector<double> probabilities = this->probabilities();
|
std::set<Key> allKeys(keys().begin(), keys().end());
|
||||||
|
std::vector<double> min_heap;
|
||||||
|
|
||||||
// The number of probabilities can be lower than max_leaves
|
auto op = [&](const Assignment<Key>& a, double p) {
|
||||||
if (probabilities.size() <= N) {
|
// Get all the keys in the current assignment
|
||||||
return *this;
|
std::set<Key> assignment_keys;
|
||||||
|
for (auto&& [k, _] : a) {
|
||||||
|
assignment_keys.insert(k);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::sort(probabilities.begin(), probabilities.end(),
|
// Find the keys missing in the assignment
|
||||||
|
std::vector<Key> diff;
|
||||||
|
std::set_difference(allKeys.begin(), allKeys.end(),
|
||||||
|
assignment_keys.begin(), assignment_keys.end(),
|
||||||
|
std::back_inserter(diff));
|
||||||
|
|
||||||
|
// Compute the total number of assignments in the (pruned) subtree
|
||||||
|
size_t nrAssignments = 1;
|
||||||
|
for (auto&& k : diff) {
|
||||||
|
nrAssignments *= cardinalities_.at(k);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (min_heap.empty()) {
|
||||||
|
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
|
||||||
|
min_heap.push_back(p);
|
||||||
|
}
|
||||||
|
std::make_heap(min_heap.begin(), min_heap.end(),
|
||||||
std::greater<double>{});
|
std::greater<double>{});
|
||||||
|
|
||||||
double threshold = probabilities[N - 1];
|
} else {
|
||||||
|
// If p is larger than the smallest element,
|
||||||
|
// then we insert into the max heap.
|
||||||
|
if (p > min_heap.at(0)) {
|
||||||
|
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
|
||||||
|
if (min_heap.size() == N) {
|
||||||
|
std::pop_heap(min_heap.begin(), min_heap.end(),
|
||||||
|
std::greater<double>{});
|
||||||
|
min_heap.pop_back();
|
||||||
|
}
|
||||||
|
min_heap.push_back(p);
|
||||||
|
std::make_heap(min_heap.begin(), min_heap.end(),
|
||||||
|
std::greater<double>{});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p;
|
||||||
|
};
|
||||||
|
this->visitWith(op);
|
||||||
|
|
||||||
|
double threshold = min_heap.at(0);
|
||||||
|
|
||||||
// Now threshold the decision tree
|
// Now threshold the decision tree
|
||||||
size_t total = 0;
|
size_t total = 0;
|
||||||
|
|
Loading…
Reference in New Issue