enumerate missing discrete keys so we can prune all gaussian mixtures
parent
a00bcbcac9
commit
2c8fe25842
|
|
@ -70,12 +70,37 @@ PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree,
|
||||||
// typecast so we can use this to get probability value
|
// typecast so we can use this to get probability value
|
||||||
DiscreteValues values(choices);
|
DiscreteValues values(choices);
|
||||||
|
|
||||||
if ((*probDecisionTree)(values) == 0.0) {
|
// Case where the gaussian mixture has the same
|
||||||
// empty aka null pointer
|
// discrete keys as the decision tree.
|
||||||
boost::shared_ptr<GaussianConditional> null;
|
if (gaussianMixtureKeySet == discreteFactorKeySet) {
|
||||||
return null;
|
if ((*probDecisionTree)(values) == 0.0) {
|
||||||
|
// empty aka null pointer
|
||||||
|
boost::shared_ptr<GaussianConditional> null;
|
||||||
|
return null;
|
||||||
|
} else {
|
||||||
|
return conditional;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return conditional;
|
std::vector<DiscreteKey> set_diff;
|
||||||
|
std::set_difference(
|
||||||
|
discreteFactorKeySet.begin(), discreteFactorKeySet.end(),
|
||||||
|
gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
|
||||||
|
std::back_inserter(set_diff));
|
||||||
|
const std::vector<DiscreteValues> assignments =
|
||||||
|
DiscreteValues::CartesianProduct(set_diff);
|
||||||
|
for (const DiscreteValues &assignment : assignments) {
|
||||||
|
DiscreteValues augmented_values(values);
|
||||||
|
augmented_values.insert(assignment.begin(), assignment.end());
|
||||||
|
|
||||||
|
// If any one of the sub-branches are non-zero,
|
||||||
|
// we need this conditional.
|
||||||
|
if ((*probDecisionTree)(augmented_values) > 0.0) {
|
||||||
|
return conditional;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If we are here, it means that all the sub-branches are 0,
|
||||||
|
// so we prune.
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
return pruner;
|
return pruner;
|
||||||
|
|
@ -112,12 +137,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||||
// We may have mixtures with less discrete keys than discreteFactor so
|
// We may have mixtures with less discrete keys than discreteFactor so
|
||||||
// we skip those since the label assignment does not exist.
|
// we skip those since the label assignment does not exist.
|
||||||
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
|
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
|
||||||
if (gmKeySet != discreteFactorKeySet) {
|
|
||||||
// Add the gaussianMixture which doesn't have to be pruned.
|
|
||||||
prunedBayesNetFragment.push_back(
|
|
||||||
boost::make_shared<HybridConditional>(gaussianMixture));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the pruner function.
|
// Get the pruner function.
|
||||||
auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet);
|
auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue