PrunerFunc helper function

release/4.3a0
Varun Agrawal 2022-10-10 16:03:36 -04:00
parent 3e151846ca
commit a00bcbcac9
1 changed files with 42 additions and 23 deletions

View File

@ -49,6 +49,38 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
return boost::make_shared<DecisionTreeFactor>(dtFactor); return boost::make_shared<DecisionTreeFactor>(dtFactor);
} }
/**
* @brief Helper function to get the pruner functional.
*
* @param probDecisionTree The probability decision tree of only discrete keys.
* @param discreteFactorKeySet Set of DiscreteKeys in probDecisionTree.
* Pre-computed for efficiency.
* @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree,
const std::set<DiscreteKey> &discreteFactorKeySet,
const std::set<DiscreteKey> &gaussianMixtureKeySet) {
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
if ((*probDecisionTree)(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
};
return pruner;
}
/* ************************************************************************* */ /* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Get the decision tree of only the discrete keys // Get the decision tree of only the discrete keys
@ -57,6 +89,8 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
boost::make_shared<DecisionTreeFactor>( boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves)); discreteConditionals->prune(maxNrLeaves));
auto discreteFactorKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
/* To Prune, we visitWith every leaf in the GaussianMixture. /* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree * For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr. * for 0.0 probability, then just set the leaf to a nullptr.
@ -66,23 +100,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
HybridBayesNet prunedBayesNetFragment; HybridBayesNet prunedBayesNetFragment;
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
if ((*discreteFactor)(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
};
// Go through all the conditionals in the // Go through all the conditionals in the
// Bayes Net and prune them as per discreteFactor. // Bayes Net and prune them as per discreteFactor.
for (size_t i = 0; i < this->size(); i++) { for (size_t i = 0; i < this->size(); i++) {
@ -92,17 +109,19 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner()); boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());
if (gaussianMixture) { if (gaussianMixture) {
// We may have mixtures with less discrete keys than discreteFactor so we // We may have mixtures with less discrete keys than discreteFactor so
// skip those since the label assignment does not exist. // we skip those since the label assignment does not exist.
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys()); auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys()); if (gmKeySet != discreteFactorKeySet) {
if (gmKeySet != dfKeySet) {
// Add the gaussianMixture which doesn't have to be pruned. // Add the gaussianMixture which doesn't have to be pruned.
prunedBayesNetFragment.push_back( prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(gaussianMixture)); boost::make_shared<HybridConditional>(gaussianMixture));
continue; continue;
} }
// Get the pruner function.
auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet);
// Run the pruning to get a new, pruned tree // Run the pruning to get a new, pruned tree
GaussianMixture::Conditionals prunedTree = GaussianMixture::Conditionals prunedTree =
gaussianMixture->conditionals().apply(pruner); gaussianMixture->conditionals().apply(pruner);
@ -173,7 +192,7 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn; return gbn;
} }
/* *******************************************************************************/ /* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const { HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE // Solve for the MPE
DiscreteBayesNet discrete_bn; DiscreteBayesNet discrete_bn;
@ -190,7 +209,7 @@ HybridValues HybridBayesNet::optimize() const {
return HybridValues(mpe, gbn.optimize()); return HybridValues(mpe, gbn.optimize());
} }
/* *******************************************************************************/ /* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment); GaussianBayesNet gbn = this->choose(assignment);
return gbn.optimize(); return gbn.optimize();