computeThreshold as an individual function

release/4.3a0
Varun Agrawal 2024-12-02 09:52:20 -05:00
parent 486feeb385
commit 8473911926
2 changed files with 32 additions and 9 deletions

View File

@ -407,11 +407,9 @@ namespace gtsam {
};
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments;
double DecisionTreeFactor::computeThreshold(const size_t N) const {
// Set of all keys
std::set<Key> allKeys(keys().begin(), keys().end());
std::set<Key> allKeys = this->labels();
MinHeap min_heap;
auto op = [&](const Assignment<Key>& a, double p) {
@ -433,18 +431,25 @@ namespace gtsam {
nrAssignments *= cardinalities_.at(k);
}
// If min-heap is empty, fill it initially.
// This is because there is nothing at the top.
if (min_heap.empty()) {
min_heap.push(p, std::min(nrAssignments, N));
} else {
// If p is larger than the smallest element,
// then we insert into the max heap.
if (p > min_heap.top()) {
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
// If p is larger than the smallest element,
// then we insert into the min heap.
// We check against the top each time because the
// heap maintains the smallest element at the top.
if (p > min_heap.top()) {
if (min_heap.size() == N) {
min_heap.pop();
}
min_heap.push(p);
} else {
// p is <= min value so move to the next one
break;
}
}
}
@ -452,7 +457,14 @@ namespace gtsam {
};
this->visitWith(op);
double threshold = min_heap.top();
return min_heap.top();
}
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments;
double threshold = computeThreshold(N);
// Now threshold the decision tree
size_t total = 0;

View File

@ -224,6 +224,17 @@ namespace gtsam {
/// Get all the probabilities in order of assignment values
std::vector<double> probabilities() const;
/**
* @brief Compute the probability threshold used when pruning this factor
* down to (at most) `N` leaves.
*
* The implementation visits the leaves of the decision tree while
* maintaining a min-heap holding at most `N` probability values, and returns
* the value at the top of that heap once every leaf has been visited, i.e.
* the smallest of the largest probabilities that were retained.
* NOTE(review): a leaf appears to be counted up to `min(nrAssignments, N)`
* times, where `nrAssignments` is built from the key cardinalities; confirm
* the exact multiplicity against the implementation.
*
* This is used for pruning out the smaller probabilities: `prune` calls this
* method and then thresholds the decision tree with the returned value.
*
* @param N The number of leaves to keep post pruning.
* @return double The threshold probability (the top of the min-heap after
* all leaves have been visited).
*/
double computeThreshold(const size_t N) const;
/**
* @brief Prune the decision tree of discrete variables.
*