Better prune

release/4.3a0
Frank Dellaert 2024-10-01 13:32:23 -07:00
parent 5b713032c1
commit b70c63ee4c
2 changed files with 16 additions and 37 deletions

View File

@ -131,7 +131,7 @@ namespace gtsam {
/// Calculate probability for given values `x`,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const DiscreteValues& values) const {
double evaluate(const Assignment<Key>& values) const {
return ADT::operator()(values);
}
@ -155,7 +155,7 @@ namespace gtsam {
return apply(f, safe_div);
}
/// Convert into a decisiontree
/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
/// Create new factor by summing all values with the same separator values

View File

@ -291,45 +291,24 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
/* *******************************************************************************/
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
const DecisionTreeFactor &discreteProbs) const {
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Find keys in discreteProbs.keys() but not in this->keys():
std::set<Key> mine(this->keys().begin(), this->keys().end());
std::set<Key> theirs(discreteProbs.keys().begin(),
discreteProbs.keys().end());
std::vector<Key> diff;
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
std::back_inserter(diff));
// Functional which loops over all assignments and create a set of
// GaussianConditionals
// Find maximum probability value for every combination of our keys.
Ordering keys(diff);
auto max = discreteProbs.max(keys);
// Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional.
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
const DiscreteValues values(choices);
// Case where the hybrid gaussian conditional has the same
// discrete keys as the decision tree.
if (hybridGaussianCondKeySet == discreteProbsKeySet) {
return (discreteProbs(values) == 0.0) ? nullptr : conditional;
} else {
// TODO(Frank): It might be faster to "choose" based on values
// and then check whether the resulting tree has non-nullptrs.
std::vector<DiscreteKey> set_diff;
std::set_difference(
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(),
std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment);
// If any one of the sub-branches are non-zero,
// we need this conditional.
if (discreteProbs(augmented_values) > 0.0) {
return conditional;
}
}
// If we are here, it means that all the sub-branches are 0, so we prune.
return nullptr;
}
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
};
auto pruned_conditionals = conditionals_.apply(pruner);