enumerate missing discrete keys so we can prune all gaussian mixtures

release/4.3a0
Varun Agrawal 2022-10-10 19:38:10 -04:00
parent a00bcbcac9
commit 2c8fe25842
1 changed files with 30 additions and 11 deletions

View File

@ -70,6 +70,9 @@ PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree,
// typecast so we can use this to get probability value
DiscreteValues values(choices);
// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
if (gaussianMixtureKeySet == discreteFactorKeySet) {
if ((*probDecisionTree)(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
@ -77,6 +80,28 @@ PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree,
} else {
return conditional;
}
} else {
std::vector<DiscreteKey> set_diff;
std::set_difference(
discreteFactorKeySet.begin(), discreteFactorKeySet.end(),
gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end());
// If any one of the sub-branches are non-zero,
// we need this conditional.
if ((*probDecisionTree)(augmented_values) > 0.0) {
return conditional;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return nullptr;
}
};
return pruner;
}
@ -112,12 +137,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// We may have mixtures with less discrete keys than discreteFactor so
// we skip those since the label assignment does not exist.
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
if (gmKeySet != discreteFactorKeySet) {
// Add the gaussianMixture which doesn't have to be pruned.
prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(gaussianMixture));
continue;
}
// Get the pruner function.
auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet);