Kill obsolete prunerFunc

release/4.3a0
Frank Dellaert 2024-09-29 23:07:36 -07:00
parent e15c44ec5c
commit d2880e9913
2 changed files with 20 additions and 101 deletions

View File

@ -37,94 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
* @param conditional Conditional to prune. Used to get full assignment.
* @return std::function<double(const Assignment<Key> &, double)>
*/
std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &prunedDiscreteProbs,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the hybrid Gaussian conditional.
std::set<DiscreteKey> discreteProbsKeySet =
DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
std::set<DiscreteKey> conditionalKeySet =
DiscreteKeysAsSet(conditional.discreteKeys());
auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
const Assignment<Key> &choices,
double probability) -> double {
// This corresponds to 0 probability
double pruned_prob = 0.0;
// typecast so we can use this to get probability value
DiscreteValues values(choices);
// Case where the hybrid Gaussian conditional has the same
// discrete keys as the decision tree.
if (conditionalKeySet == discreteProbsKeySet) {
if (prunedDiscreteProbs(values) == 0) {
return pruned_prob;
} else {
return probability;
}
} else {
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
// get a `values` which doesn't have the full set of keys.
std::set<Key> valuesKeys;
for (auto kvp : values) {
valuesKeys.insert(kvp.first);
}
std::set<Key> conditionalKeys;
for (auto kvp : conditionalKeySet) {
conditionalKeys.insert(kvp.first);
}
// If true, then values is missing some keys
if (conditionalKeys != valuesKeys) {
// Get the keys present in conditionalKeys but not in valuesKeys
std::vector<Key> missing_keys;
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
valuesKeys.begin(), valuesKeys.end(),
std::back_inserter(missing_keys));
// Insert missing keys with a default assignment.
for (auto missing_key : missing_keys) {
values[missing_key] = 0;
}
}
// Now we generate the full assignment by enumerating
// over all keys in the prunedDiscreteProbs.
// First we find the differing keys
std::vector<DiscreteKey> set_diff;
std::set_difference(discreteProbsKeySet.begin(),
discreteProbsKeySet.end(), conditionalKeySet.begin(),
conditionalKeySet.end(),
std::back_inserter(set_diff));
// Now enumerate over all assignments of the differing keys
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment);
// If any one of the sub-branches are non-zero,
// we need this probability.
if (prunedDiscreteProbs(augmented_values) > 0.0) {
return probability;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return pruned_prob;
}
};
return pruner;
}
/* ************************************************************************* */
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t maxNrLeaves) {
@ -164,9 +76,10 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
}
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
HybridBayesNet copy(*this);
DecisionTreeFactor prunedDiscreteProbs =
this->pruneDiscreteConditionals(maxNrLeaves);
copy.pruneDiscreteConditionals(maxNrLeaves);
/* To prune, we visitWith every leaf in the HybridGaussianConditional.
* For each leaf, using the assignment we can check the discrete decision tree
@ -179,13 +92,10 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Go through all the conditionals in the
// Bayes Net and prune them as per prunedDiscreteProbs.
for (auto &&conditional : *this) {
for (auto &&conditional : copy) {
if (auto gm = conditional->asHybrid()) {
// Make a copy of the hybrid Gaussian conditional and prune it!
auto prunedHybridGaussianConditional =
std::make_shared<HybridGaussianConditional>(*gm);
prunedHybridGaussianConditional->prune(
prunedDiscreteProbs); // imperative :-(
auto prunedHybridGaussianConditional = gm->prune(prunedDiscreteProbs);
// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(prunedHybridGaussianConditional);
@ -336,10 +246,14 @@ AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
});
} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete logProbability in the right branch
result = result.apply(
[dc](const Assignment<Key> &assignment, double leaf_value) {
return leaf_value + dc->logProbability(DiscreteValues(assignment));
});
if (result.nrLeaves() == 1) {
result = dc->errorTree().apply([](double error) { return -error; });
} else {
result = result.apply([dc](const Assignment<Key> &assignment,
double leaf_value) {
return leaf_value + dc->logProbability(DiscreteValues(assignment));
});
}
}
}

View File

@ -201,8 +201,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
HybridValues sample() const;
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves);
/**
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
*
* @param maxNrLeaves Continuous values at which to compute the error.
* @return A pruned HybridBayesNet
*/
HybridBayesNet prune(size_t maxNrLeaves) const;
/**
* @brief Compute conditional error for each discrete assignment,