Better prune
parent
5b713032c1
commit
b70c63ee4c
|
@ -131,7 +131,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/// Calculate probability for given values `x`,
|
/// Calculate probability for given values `x`,
|
||||||
/// is just look up in AlgebraicDecisionTree.
|
/// is just look up in AlgebraicDecisionTree.
|
||||||
double evaluate(const DiscreteValues& values) const {
|
double evaluate(const Assignment<Key>& values) const {
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,7 +155,7 @@ namespace gtsam {
|
||||||
return apply(f, safe_div);
|
return apply(f, safe_div);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Convert into a decision tree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
|
|
|
@ -291,45 +291,24 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
const DecisionTreeFactor &discreteProbs) const {
|
const DecisionTreeFactor &discreteProbs) const {
|
||||||
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
// Find keys in discreteProbs.keys() but not in this->keys():
|
||||||
auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
std::set<Key> mine(this->keys().begin(), this->keys().end());
|
||||||
|
std::set<Key> theirs(discreteProbs.keys().begin(),
|
||||||
|
discreteProbs.keys().end());
|
||||||
|
std::vector<Key> diff;
|
||||||
|
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
|
||||||
|
std::back_inserter(diff));
|
||||||
|
|
||||||
// Functional which loops over all assignments and create a set of
|
// Find maximum probability value for every combination of our keys.
|
||||||
// GaussianConditionals
|
Ordering keys(diff);
|
||||||
|
auto max = discreteProbs.max(keys);
|
||||||
|
|
||||||
|
// Check the max value for every combination of our keys.
|
||||||
|
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||||
auto pruner = [&](const Assignment<Key> &choices,
|
auto pruner = [&](const Assignment<Key> &choices,
|
||||||
const GaussianConditional::shared_ptr &conditional)
|
const GaussianConditional::shared_ptr &conditional)
|
||||||
-> GaussianConditional::shared_ptr {
|
-> GaussianConditional::shared_ptr {
|
||||||
// typecast so we can use this to get probability value
|
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
|
||||||
const DiscreteValues values(choices);
|
|
||||||
|
|
||||||
// Case where the hybrid gaussian conditional has the same
|
|
||||||
// discrete keys as the decision tree.
|
|
||||||
if (hybridGaussianCondKeySet == discreteProbsKeySet) {
|
|
||||||
return (discreteProbs(values) == 0.0) ? nullptr : conditional;
|
|
||||||
} else {
|
|
||||||
// TODO(Frank): It might be faster to "choose" based on values
|
|
||||||
// and then check whether the resulting tree has non-nullptrs.
|
|
||||||
std::vector<DiscreteKey> set_diff;
|
|
||||||
std::set_difference(
|
|
||||||
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
|
|
||||||
hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(),
|
|
||||||
std::back_inserter(set_diff));
|
|
||||||
|
|
||||||
const std::vector<DiscreteValues> assignments =
|
|
||||||
DiscreteValues::CartesianProduct(set_diff);
|
|
||||||
for (const DiscreteValues &assignment : assignments) {
|
|
||||||
DiscreteValues augmented_values(values);
|
|
||||||
augmented_values.insert(assignment);
|
|
||||||
|
|
||||||
// If any one of the sub-branches are non-zero,
|
|
||||||
// we need this conditional.
|
|
||||||
if (discreteProbs(augmented_values) > 0.0) {
|
|
||||||
return conditional;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If we are here, it means that all the sub-branches are 0, so we prune.
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||||
|
|
Loading…
Reference in New Issue