computeThreshold as an individual function
parent
486feeb385
commit
8473911926
|
@ -407,11 +407,9 @@ namespace gtsam {
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
double DecisionTreeFactor::computeThreshold(const size_t N) const {
|
||||||
const size_t N = maxNrAssignments;
|
|
||||||
|
|
||||||
// Set of all keys
|
// Set of all keys
|
||||||
std::set<Key> allKeys(keys().begin(), keys().end());
|
std::set<Key> allKeys = this->labels();
|
||||||
MinHeap min_heap;
|
MinHeap min_heap;
|
||||||
|
|
||||||
auto op = [&](const Assignment<Key>& a, double p) {
|
auto op = [&](const Assignment<Key>& a, double p) {
|
||||||
|
@ -433,18 +431,25 @@ namespace gtsam {
|
||||||
nrAssignments *= cardinalities_.at(k);
|
nrAssignments *= cardinalities_.at(k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If min-heap is empty, fill it initially.
|
||||||
|
// This is because there is nothing at the top.
|
||||||
if (min_heap.empty()) {
|
if (min_heap.empty()) {
|
||||||
min_heap.push(p, std::min(nrAssignments, N));
|
min_heap.push(p, std::min(nrAssignments, N));
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// If p is larger than the smallest element,
|
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
|
||||||
// then we insert into the max heap.
|
// If p is larger than the smallest element,
|
||||||
if (p > min_heap.top()) {
|
// then we insert into the min heap.
|
||||||
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
|
// We check against the top each time because the
|
||||||
|
// heap maintains the smallest element at the top.
|
||||||
|
if (p > min_heap.top()) {
|
||||||
if (min_heap.size() == N) {
|
if (min_heap.size() == N) {
|
||||||
min_heap.pop();
|
min_heap.pop();
|
||||||
}
|
}
|
||||||
min_heap.push(p);
|
min_heap.push(p);
|
||||||
|
} else {
|
||||||
|
// p is <= min value so move to the next one
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -452,7 +457,14 @@ namespace gtsam {
|
||||||
};
|
};
|
||||||
this->visitWith(op);
|
this->visitWith(op);
|
||||||
|
|
||||||
double threshold = min_heap.top();
|
return min_heap.top();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
|
||||||
|
const size_t N = maxNrAssignments;
|
||||||
|
|
||||||
|
double threshold = computeThreshold(N);
|
||||||
|
|
||||||
// Now threshold the decision tree
|
// Now threshold the decision tree
|
||||||
size_t total = 0;
|
size_t total = 0;
|
||||||
|
|
|
@ -224,6 +224,17 @@ namespace gtsam {
|
||||||
/// Get all the probabilities in order of assignment values
|
/// Get all the probabilities in order of assignment values
|
||||||
std::vector<double> probabilities() const;
|
std::vector<double> probabilities() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute the probability value which is the threshold above which
|
||||||
|
* only `N` leaves are present.
|
||||||
|
*
|
||||||
|
* This is used for pruning out the smaller probabilities.
|
||||||
|
*
|
||||||
|
* @param N The number of leaves to keep post pruning.
|
||||||
|
* @return double
|
||||||
|
*/
|
||||||
|
double computeThreshold(const size_t N) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of discrete variables.
|
* @brief Prune the decision tree of discrete variables.
|
||||||
*
|
*
|
||||||
|
|
Loading…
Reference in New Issue