computeThreshold as an individual function
parent
486feeb385
commit
8473911926
|
@ -407,11 +407,9 @@ namespace gtsam {
|
|||
};
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
||||
const size_t N = maxNrAssignments;
|
||||
|
||||
double DecisionTreeFactor::computeThreshold(const size_t N) const {
|
||||
// Set of all keys
|
||||
std::set<Key> allKeys(keys().begin(), keys().end());
|
||||
std::set<Key> allKeys = this->labels();
|
||||
MinHeap min_heap;
|
||||
|
||||
auto op = [&](const Assignment<Key>& a, double p) {
|
||||
|
@ -433,18 +431,25 @@ namespace gtsam {
|
|||
nrAssignments *= cardinalities_.at(k);
|
||||
}
|
||||
|
||||
// If min-heap is empty, fill it initially.
|
||||
// This is because there is nothing at the top.
|
||||
if (min_heap.empty()) {
|
||||
min_heap.push(p, std::min(nrAssignments, N));
|
||||
|
||||
} else {
|
||||
// If p is larger than the smallest element,
|
||||
// then we insert into the min heap.
|
||||
if (p > min_heap.top()) {
|
||||
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
|
||||
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
|
||||
// If p is larger than the smallest element,
|
||||
// then we insert into the min heap.
|
||||
// We check against the top each time because the
|
||||
// heap maintains the smallest element at the top.
|
||||
if (p > min_heap.top()) {
|
||||
if (min_heap.size() == N) {
|
||||
min_heap.pop();
|
||||
}
|
||||
min_heap.push(p);
|
||||
} else {
|
||||
// p is <= min value so move to the next one
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -452,7 +457,14 @@ namespace gtsam {
|
|||
};
|
||||
this->visitWith(op);
|
||||
|
||||
double threshold = min_heap.top();
|
||||
return min_heap.top();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
||||
const size_t N = maxNrAssignments;
|
||||
|
||||
double threshold = computeThreshold(N);
|
||||
|
||||
// Now threshold the decision tree
|
||||
size_t total = 0;
|
||||
|
|
|
@ -224,6 +224,17 @@ namespace gtsam {
|
|||
/// Get all the probabilities in order of assignment values
|
||||
std::vector<double> probabilities() const;
|
||||
|
||||
/**
|
||||
* @brief Compute the probability threshold such that only `N` leaves
|
||||
* remain once the smaller probabilities are pruned out.
|
||||
*
|
||||
* This is used for pruning out the smaller probabilities.
|
||||
*
|
||||
* @param N The number of leaves to keep post pruning.
|
||||
* @return double The probability threshold to use for pruning.
|
||||
*/
|
||||
double computeThreshold(const size_t N) const;
|
||||
|
||||
/**
|
||||
* @brief Prune the decision tree of discrete variables.
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue