Make prune functional
parent
caa3821b2b
commit
e15c44ec5c
|
@ -342,13 +342,57 @@ HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
void HybridGaussianConditional::prune(const DecisionTreeFactor &discreteProbs) {
|
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
|
const DecisionTreeFactor &discreteProbs) const {
|
||||||
|
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
||||||
|
auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||||
|
|
||||||
// Functional which loops over all assignments and create a set of
|
// Functional which loops over all assignments and create a set of
|
||||||
// GaussianConditionals
|
// GaussianConditionals
|
||||||
auto pruner = prunerFunc(discreteProbs);
|
auto pruner = [&](const Assignment<Key> &choices,
|
||||||
|
const GaussianConditional::shared_ptr &conditional)
|
||||||
|
-> GaussianConditional::shared_ptr {
|
||||||
|
// typecast so we can use this to get probability value
|
||||||
|
const DiscreteValues values(choices);
|
||||||
|
|
||||||
|
// Case where the hybrid gaussian conditional has the same
|
||||||
|
// discrete keys as the decision tree.
|
||||||
|
if (hybridGaussianCondKeySet == discreteProbsKeySet) {
|
||||||
|
if (discreteProbs(values) == 0.0) {
|
||||||
|
// empty aka null pointer
|
||||||
|
std::shared_ptr<GaussianConditional> null;
|
||||||
|
return null;
|
||||||
|
} else {
|
||||||
|
return conditional;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::vector<DiscreteKey> set_diff;
|
||||||
|
std::set_difference(
|
||||||
|
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
|
||||||
|
hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(),
|
||||||
|
std::back_inserter(set_diff));
|
||||||
|
|
||||||
|
const std::vector<DiscreteValues> assignments =
|
||||||
|
DiscreteValues::CartesianProduct(set_diff);
|
||||||
|
for (const DiscreteValues &assignment : assignments) {
|
||||||
|
DiscreteValues augmented_values(values);
|
||||||
|
augmented_values.insert(assignment);
|
||||||
|
|
||||||
|
// If any one of the sub-branches are non-zero,
|
||||||
|
// we need this conditional.
|
||||||
|
if (discreteProbs(augmented_values) > 0.0) {
|
||||||
|
return conditional;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If we are here, it means that all the sub-branches are 0,
|
||||||
|
// so we prune.
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||||
conditionals_.root_ = pruned_conditionals.root_;
|
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
||||||
|
pruned_conditionals);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
@ -225,8 +225,10 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
* `discreteProbs`.
|
* `discreteProbs`.
|
||||||
*
|
*
|
||||||
* @param discreteProbs A pruned set of probabilities for the discrete keys.
|
* @param discreteProbs A pruned set of probabilities for the discrete keys.
|
||||||
|
* @return Shared pointer to possibly a pruned HybridGaussianConditional
|
||||||
*/
|
*/
|
||||||
void prune(const DecisionTreeFactor &discreteProbs);
|
HybridGaussianConditional::shared_ptr prune(
|
||||||
|
const DecisionTreeFactor &discreteProbs) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue