Make prune functional
parent
caa3821b2b
commit
e15c44ec5c
|
@ -342,13 +342,57 @@ HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
void HybridGaussianConditional::prune(const DecisionTreeFactor &discreteProbs) {
|
||||
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||
const DecisionTreeFactor &discreteProbs) const {
|
||||
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
||||
auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||
|
||||
// Functional which loops over all assignments and create a set of
|
||||
// GaussianConditionals
|
||||
auto pruner = prunerFunc(discreteProbs);
|
||||
auto pruner = [&](const Assignment<Key> &choices,
|
||||
const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianConditional::shared_ptr {
|
||||
// typecast so we can use this to get probability value
|
||||
const DiscreteValues values(choices);
|
||||
|
||||
// Case where the hybrid gaussian conditional has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (hybridGaussianCondKeySet == discreteProbsKeySet) {
|
||||
if (discreteProbs(values) == 0.0) {
|
||||
// empty aka null pointer
|
||||
std::shared_ptr<GaussianConditional> null;
|
||||
return null;
|
||||
} else {
|
||||
return conditional;
|
||||
}
|
||||
} else {
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(
|
||||
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
|
||||
hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(set_diff);
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
DiscreteValues augmented_values(values);
|
||||
augmented_values.insert(assignment);
|
||||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this conditional.
|
||||
if (discreteProbs(augmented_values) > 0.0) {
|
||||
return conditional;
|
||||
}
|
||||
}
|
||||
// If we are here, it means that all the sub-branches are 0,
|
||||
// so we prune.
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||
conditionals_.root_ = pruned_conditionals.root_;
|
||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
||||
pruned_conditionals);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -225,8 +225,10 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
* `discreteProbs`.
|
||||
*
|
||||
* @param discreteProbs A pruned set of probabilities for the discrete keys.
|
||||
* @return Shared pointer to possibly a pruned HybridGaussianConditional
|
||||
*/
|
||||
void prune(const DecisionTreeFactor &discreteProbs);
|
||||
HybridGaussianConditional::shared_ptr prune(
|
||||
const DecisionTreeFactor &discreteProbs) const;
|
||||
|
||||
/// @}
|
||||
|
||||
|
|
Loading…
Reference in New Issue