Inline lambda
parent
d2880e9913
commit
28f5ed0a6e
|
@ -288,59 +288,6 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
std::function<GaussianConditional::shared_ptr(
|
|
||||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
|
||||||
HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
|
|
||||||
// Get the discrete keys as sets for the decision tree
|
|
||||||
// and the hybrid gaussian conditional.
|
|
||||||
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
|
||||||
auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
|
||||||
|
|
||||||
auto pruner = [discreteProbs, discreteProbsKeySet, hybridGaussianCondKeySet](
|
|
||||||
const Assignment<Key> &choices,
|
|
||||||
const GaussianConditional::shared_ptr &conditional)
|
|
||||||
-> GaussianConditional::shared_ptr {
|
|
||||||
// typecast so we can use this to get probability value
|
|
||||||
const DiscreteValues values(choices);
|
|
||||||
|
|
||||||
// Case where the hybrid gaussian conditional has the same
|
|
||||||
// discrete keys as the decision tree.
|
|
||||||
if (hybridGaussianCondKeySet == discreteProbsKeySet) {
|
|
||||||
if (discreteProbs(values) == 0.0) {
|
|
||||||
// empty aka null pointer
|
|
||||||
std::shared_ptr<GaussianConditional> null;
|
|
||||||
return null;
|
|
||||||
} else {
|
|
||||||
return conditional;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
std::vector<DiscreteKey> set_diff;
|
|
||||||
std::set_difference(
|
|
||||||
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
|
|
||||||
hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(),
|
|
||||||
std::back_inserter(set_diff));
|
|
||||||
|
|
||||||
const std::vector<DiscreteValues> assignments =
|
|
||||||
DiscreteValues::CartesianProduct(set_diff);
|
|
||||||
for (const DiscreteValues &assignment : assignments) {
|
|
||||||
DiscreteValues augmented_values(values);
|
|
||||||
augmented_values.insert(assignment);
|
|
||||||
|
|
||||||
// If any one of the sub-branches are non-zero,
|
|
||||||
// we need this conditional.
|
|
||||||
if (discreteProbs(augmented_values) > 0.0) {
|
|
||||||
return conditional;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If we are here, it means that all the sub-branches are 0,
|
|
||||||
// so we prune.
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
return pruner;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
const DecisionTreeFactor &discreteProbs) const {
|
const DecisionTreeFactor &discreteProbs) const {
|
||||||
|
@ -358,14 +305,10 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
// Case where the hybrid gaussian conditional has the same
|
// Case where the hybrid gaussian conditional has the same
|
||||||
// discrete keys as the decision tree.
|
// discrete keys as the decision tree.
|
||||||
if (hybridGaussianCondKeySet == discreteProbsKeySet) {
|
if (hybridGaussianCondKeySet == discreteProbsKeySet) {
|
||||||
if (discreteProbs(values) == 0.0) {
|
return (discreteProbs(values) == 0.0) ? nullptr : conditional;
|
||||||
// empty aka null pointer
|
|
||||||
std::shared_ptr<GaussianConditional> null;
|
|
||||||
return null;
|
|
||||||
} else {
|
|
||||||
return conditional;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
|
// TODO(Frank): It might be faster to "choose" based on values
|
||||||
|
// and then check whether the resulting tree has non-nullptrs.
|
||||||
std::vector<DiscreteKey> set_diff;
|
std::vector<DiscreteKey> set_diff;
|
||||||
std::set_difference(
|
std::set_difference(
|
||||||
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
|
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
|
||||||
|
@ -384,8 +327,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
return conditional;
|
return conditional;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// If we are here, it means that all the sub-branches are 0,
|
// If we are here, it means that all the sub-branches are 0, so we prune.
|
||||||
// so we prune.
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -243,17 +243,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
/// Convert to a DecisionTree of Gaussian factor graphs.
|
/// Convert to a DecisionTree of Gaussian factor graphs.
|
||||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get the pruner function from discrete probabilities.
|
|
||||||
*
|
|
||||||
* @param discreteProbs The probabilities of only discrete keys.
|
|
||||||
* @return std::function<GaussianConditional::shared_ptr(
|
|
||||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
|
||||||
*/
|
|
||||||
std::function<GaussianConditional::shared_ptr(
|
|
||||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
|
||||||
prunerFunc(const DecisionTreeFactor &prunedProbabilities);
|
|
||||||
|
|
||||||
/// Check whether `given` has values for all frontal keys.
|
/// Check whether `given` has values for all frontal keys.
|
||||||
bool allFrontalsGiven(const VectorValues &given) const;
|
bool allFrontalsGiven(const VectorValues &given) const;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue