add PrunerFunc to GaussianMixture

release/4.3a0
Varun Agrawal 2022-10-11 12:10:02 -04:00
parent 2c8fe25842
commit c15cfb6068
1 changed files with 67 additions and 11 deletions

View File

@ -122,31 +122,87 @@ void GaussianMixture::print(const std::string &s,
if (gf && !gf->empty()) { if (gf && !gf->empty()) {
gf->print("", formatter); gf->print("", formatter);
return rd.str(); return rd.str();
// return "Node()";
} else { } else {
return "nullptr"; return "nullptr";
} }
}); });
} }
/* *******************************************************************************/ /* ************************************************************************* */
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { /// Return the DiscreteKey vector as a set.
// Functional which loops over all assignments and create a set of static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
// GaussianConditionals std::set<DiscreteKey> s;
auto pruner = [&decisionTree]( s.insert(dkeys.begin(), dkeys.end());
const Assignment<Key> &choices, return s;
}
/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @param decisionTreeKeySet Set of DiscreteKeys in decisionTree.
* Pre-computed for efficiency.
* @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
PrunerFunc(const DecisionTreeFactor &decisionTree,
const std::set<DiscreteKey> &decisionTreeKeySet,
const std::set<DiscreteKey> &gaussianMixtureKeySet) {
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional) const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr { -> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
DiscreteValues values(choices); DiscreteValues values(choices);
if (decisionTree(values) == 0.0) { // Case where the gaussian mixture has the same
// empty aka null pointer // discrete keys as the decision tree.
boost::shared_ptr<GaussianConditional> null; if (gaussianMixtureKeySet == decisionTreeKeySet) {
return null; if (decisionTree(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
} else { } else {
return conditional; std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
gaussianMixtureKeySet.begin(),
gaussianMixtureKeySet.end(),
std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end());
// If any one of the sub-branches are non-zero,
// we need this conditional.
if (decisionTree(augmented_values) > 0.0) {
return conditional;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return nullptr;
} }
}; };
return pruner;
}
/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = PrunerFunc(decisionTree, decisionTreeKeySet, gmKeySet);
auto pruned_conditionals = conditionals_.apply(pruner); auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_; conditionals_.root_ = pruned_conditionals.root_;