enumerate missing discrete keys so we can prune all gaussian mixtures
parent
a00bcbcac9
commit
2c8fe25842
|
|
@ -70,6 +70,9 @@ PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree,
|
|||
// typecast so we can use this to get probability value
|
||||
DiscreteValues values(choices);
|
||||
|
||||
// Case where the gaussian mixture has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (gaussianMixtureKeySet == discreteFactorKeySet) {
|
||||
if ((*probDecisionTree)(values) == 0.0) {
|
||||
// empty aka null pointer
|
||||
boost::shared_ptr<GaussianConditional> null;
|
||||
|
|
@ -77,6 +80,28 @@ PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree,
|
|||
} else {
|
||||
return conditional;
|
||||
}
|
||||
} else {
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(
|
||||
discreteFactorKeySet.begin(), discreteFactorKeySet.end(),
|
||||
gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(set_diff);
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
DiscreteValues augmented_values(values);
|
||||
augmented_values.insert(assignment.begin(), assignment.end());
|
||||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this conditional.
|
||||
if ((*probDecisionTree)(augmented_values) > 0.0) {
|
||||
return conditional;
|
||||
}
|
||||
}
|
||||
// If we are here, it means that all the sub-branches are 0,
|
||||
// so we prune.
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
return pruner;
|
||||
}
|
||||
|
|
@ -112,12 +137,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
|||
// We may have mixtures with less discrete keys than discreteFactor so
|
||||
// we skip those since the label assignment does not exist.
|
||||
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
|
||||
if (gmKeySet != discreteFactorKeySet) {
|
||||
// Add the gaussianMixture which doesn't have to be pruned.
|
||||
prunedBayesNetFragment.push_back(
|
||||
boost::make_shared<HybridConditional>(gaussianMixture));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get the pruner function.
|
||||
auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet);
|
||||
|
|
|
|||
Loading…
Reference in New Issue