Kill obsolete prunerFunc

parent e15c44ec5c
commit d2880e9913
@@ -37,94 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
   return Base::equals(bn, tol);
 }
 
-/* ************************************************************************* */
-/**
- * @brief Helper function to get the pruner functional.
- *
- * @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
- * @param conditional Conditional to prune. Used to get full assignment.
- * @return std::function<double(const Assignment<Key> &, double)>
- */
-std::function<double(const Assignment<Key> &, double)> prunerFunc(
-    const DecisionTreeFactor &prunedDiscreteProbs,
-    const HybridConditional &conditional) {
-  // Get the discrete keys as sets for the decision tree
-  // and the hybrid Gaussian conditional.
-  std::set<DiscreteKey> discreteProbsKeySet =
-      DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
-  std::set<DiscreteKey> conditionalKeySet =
-      DiscreteKeysAsSet(conditional.discreteKeys());
-
-  auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
-                    const Assignment<Key> &choices,
-                    double probability) -> double {
-    // This corresponds to 0 probability
-    double pruned_prob = 0.0;
-
-    // typecast so we can use this to get probability value
-    DiscreteValues values(choices);
-    // Case where the hybrid Gaussian conditional has the same
-    // discrete keys as the decision tree.
-    if (conditionalKeySet == discreteProbsKeySet) {
-      if (prunedDiscreteProbs(values) == 0) {
-        return pruned_prob;
-      } else {
-        return probability;
-      }
-    } else {
-      // Due to branch merging (aka pruning) in DecisionTree, it is possible we
-      // get a `values` which doesn't have the full set of keys.
-      std::set<Key> valuesKeys;
-      for (auto kvp : values) {
-        valuesKeys.insert(kvp.first);
-      }
-      std::set<Key> conditionalKeys;
-      for (auto kvp : conditionalKeySet) {
-        conditionalKeys.insert(kvp.first);
-      }
-      // If true, then values is missing some keys
-      if (conditionalKeys != valuesKeys) {
-        // Get the keys present in conditionalKeys but not in valuesKeys
-        std::vector<Key> missing_keys;
-        std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
-                            valuesKeys.begin(), valuesKeys.end(),
-                            std::back_inserter(missing_keys));
-        // Insert missing keys with a default assignment.
-        for (auto missing_key : missing_keys) {
-          values[missing_key] = 0;
-        }
-      }
-
-      // Now we generate the full assignment by enumerating
-      // over all keys in the prunedDiscreteProbs.
-      // First we find the differing keys
-      std::vector<DiscreteKey> set_diff;
-      std::set_difference(discreteProbsKeySet.begin(),
-                          discreteProbsKeySet.end(), conditionalKeySet.begin(),
-                          conditionalKeySet.end(),
-                          std::back_inserter(set_diff));
-
-      // Now enumerate over all assignments of the differing keys
-      const std::vector<DiscreteValues> assignments =
-          DiscreteValues::CartesianProduct(set_diff);
-      for (const DiscreteValues &assignment : assignments) {
-        DiscreteValues augmented_values(values);
-        augmented_values.insert(assignment);
-
-        // If any one of the sub-branches are non-zero,
-        // we need this probability.
-        if (prunedDiscreteProbs(augmented_values) > 0.0) {
-          return probability;
-        }
-      }
-      // If we are here, it means that all the sub-branches are 0,
-      // so we prune.
-      return pruned_prob;
-    }
-  };
-  return pruner;
-}
-
 /* ************************************************************************* */
 DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
     size_t maxNrLeaves) {
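Note: the per-conditional pruning that this removed helper implemented is superseded;
as the hunks below show, each hybrid Gaussian conditional is now pruned directly against
the pruned discrete probabilities via

    auto prunedHybridGaussianConditional = gm->prune(prunedDiscreteProbs);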
@@ -164,9 +76,10 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
 }
 
 /* ************************************************************************* */
-HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
+HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
+  HybridBayesNet copy(*this);
   DecisionTreeFactor prunedDiscreteProbs =
-      this->pruneDiscreteConditionals(maxNrLeaves);
+      copy.pruneDiscreteConditionals(maxNrLeaves);
 
   /* To prune, we visitWith every leaf in the HybridGaussianConditional.
   * For each leaf, using the assignment we can check the discrete decision tree
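The signature change above makes prune() usable on a const HybridBayesNet: the net is
copied once up front, the discrete conditionals are pruned on that copy, and the copy is
what gets traversed below. A minimal, generic sketch of this copy-then-mutate pattern
(illustrative only, not GTSAM code; class and method names are made up):

    #include <cstddef>
    #include <vector>

    class Counts {
     public:
      // In-place, non-const operation (analogous to pruneDiscreteConditionals).
      void truncate(std::size_t maxSize) {
        if (data_.size() > maxSize) data_.resize(maxSize);
      }

      // Const variant (analogous to the new prune): copy *this, mutate only the copy.
      Counts truncated(std::size_t maxSize) const {
        Counts copy(*this);
        copy.truncate(maxSize);
        return copy;
      }

     private:
      std::vector<int> data_{1, 2, 3, 4, 5};
    };

The trade-off is one extra copy per call in exchange for leaving the caller's object untouched.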
@@ -179,13 +92,10 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
 
   // Go through all the conditionals in the
   // Bayes Net and prune them as per prunedDiscreteProbs.
-  for (auto &&conditional : *this) {
+  for (auto &&conditional : copy) {
     if (auto gm = conditional->asHybrid()) {
       // Make a copy of the hybrid Gaussian conditional and prune it!
-      auto prunedHybridGaussianConditional =
-          std::make_shared<HybridGaussianConditional>(*gm);
-      prunedHybridGaussianConditional->prune(
-          prunedDiscreteProbs);  // imperative :-(
+      auto prunedHybridGaussianConditional = gm->prune(prunedDiscreteProbs);
 
       // Type-erase and add to the pruned Bayes Net fragment.
       prunedBayesNetFragment.push_back(prunedHybridGaussianConditional);
@@ -336,10 +246,14 @@ AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
       });
     } else if (auto dc = conditional->asDiscrete()) {
       // If discrete, add the discrete logProbability in the right branch
-      result = result.apply(
-          [dc](const Assignment<Key> &assignment, double leaf_value) {
-            return leaf_value + dc->logProbability(DiscreteValues(assignment));
-          });
+      if (result.nrLeaves() == 1) {
+        result = dc->errorTree().apply([](double error) { return -error; });
+      } else {
+        result = result.apply([dc](const Assignment<Key> &assignment,
+                                   double leaf_value) {
+          return leaf_value + dc->logProbability(DiscreteValues(assignment));
+        });
+      }
     }
   }
 
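The added branch above handles the case where result is still a single constant leaf
(nothing accumulated yet): presumably because applying the assignment-dependent lambda to
such a tree would only see an empty assignment, the result is instead seeded from the
discrete conditional's error tree, negated, since for the discrete conditional the error
is the negative log-probability. A sketch of that branch factored into a hypothetical
helper, using only calls that appear in this diff (includes and helper name are assumptions):

    #include <gtsam/discrete/AlgebraicDecisionTree.h>
    #include <gtsam/discrete/DiscreteConditional.h>

    using namespace gtsam;

    // Add a discrete conditional's log-probability into an accumulator tree.
    static void addDiscreteLogProbability(const DiscreteConditional::shared_ptr &dc,
                                          AlgebraicDecisionTree<Key> *result) {
      if (result->nrLeaves() == 1) {
        // Seed the tree directly; -error == logProbability for the discrete conditional.
        *result = dc->errorTree().apply([](double error) { return -error; });
      } else {
        // General case: add the log-probability leaf-wise over full assignments.
        *result = result->apply([dc](const Assignment<Key> &assignment,
                                     double leaf_value) {
          return leaf_value + dc->logProbability(DiscreteValues(assignment));
        });
      }
    }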
@@ -201,8 +201,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   HybridValues sample() const;
 
-  /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
-  HybridBayesNet prune(size_t maxNrLeaves);
+  /**
+   * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
+   *
+   * @param maxNrLeaves The maximum number of leaves to keep after pruning.
+   * @return A pruned HybridBayesNet
+   */
+  HybridBayesNet prune(size_t maxNrLeaves) const;
 
   /**
    * @brief Compute conditional error for each discrete assignment,
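A minimal usage sketch of the new const interface (the helper name and include are
illustrative, not part of this change):

    #include <cstddef>

    #include <gtsam/hybrid/HybridBayesNet.h>

    using namespace gtsam;

    // Prune a hybrid Bayes net down to at most maxNrLeaves leaves without modifying it;
    // before this change, prune() was non-const and could not be called through a
    // const reference like this.
    HybridBayesNet PruneToBudget(const HybridBayesNet &posterior, size_t maxNrLeaves) {
      return posterior.prune(maxNrLeaves);  // returns a pruned copy
    }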