Make prune functional

release/4.3a0
Frank Dellaert 2024-09-29 22:56:41 -07:00
parent caa3821b2b
commit e15c44ec5c
2 changed files with 50 additions and 4 deletions

View File

@ -342,13 +342,57 @@ HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
} }
/* *******************************************************************************/ /* *******************************************************************************/
void HybridGaussianConditional::prune(const DecisionTreeFactor &discreteProbs) { HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
const DecisionTreeFactor &discreteProbs) const {
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of // Functional which loops over all assignments and create a set of
// GaussianConditionals // GaussianConditionals
auto pruner = prunerFunc(discreteProbs); auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
const DiscreteValues values(choices);
// Case where the hybrid gaussian conditional has the same
// discrete keys as the decision tree.
if (hybridGaussianCondKeySet == discreteProbsKeySet) {
if (discreteProbs(values) == 0.0) {
// empty aka null pointer
std::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
} else {
std::vector<DiscreteKey> set_diff;
std::set_difference(
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(),
std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment);
// If any one of the sub-branches are non-zero,
// we need this conditional.
if (discreteProbs(augmented_values) > 0.0) {
return conditional;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return nullptr;
}
};
auto pruned_conditionals = conditionals_.apply(pruner); auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_; return std::make_shared<HybridGaussianConditional>(discreteKeys(),
pruned_conditionals);
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -225,8 +225,10 @@ class GTSAM_EXPORT HybridGaussianConditional
* `discreteProbs`. * `discreteProbs`.
* *
* @param discreteProbs A pruned set of probabilities for the discrete keys. * @param discreteProbs A pruned set of probabilities for the discrete keys.
* @return Shared pointer to possibly a pruned HybridGaussianConditional
*/ */
void prune(const DecisionTreeFactor &discreteProbs); HybridGaussianConditional::shared_ptr prune(
const DecisionTreeFactor &discreteProbs) const;
/// @} /// @}