PrunerFunc helper function
parent
3e151846ca
commit
a00bcbcac9
|
@ -49,6 +49,38 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||||
return boost::make_shared<DecisionTreeFactor>(dtFactor);
|
return boost::make_shared<DecisionTreeFactor>(dtFactor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Helper function to get the pruner functional.
|
||||||
|
*
|
||||||
|
* @param probDecisionTree The probability decision tree of only discrete keys.
|
||||||
|
* @param discreteFactorKeySet Set of DiscreteKeys in probDecisionTree.
|
||||||
|
* Pre-computed for efficiency.
|
||||||
|
* @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture.
|
||||||
|
* @return std::function<GaussianConditional::shared_ptr(
|
||||||
|
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||||
|
*/
|
||||||
|
std::function<GaussianConditional::shared_ptr(
|
||||||
|
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||||
|
PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree,
|
||||||
|
const std::set<DiscreteKey> &discreteFactorKeySet,
|
||||||
|
const std::set<DiscreteKey> &gaussianMixtureKeySet) {
|
||||||
|
auto pruner = [&](const Assignment<Key> &choices,
|
||||||
|
const GaussianConditional::shared_ptr &conditional)
|
||||||
|
-> GaussianConditional::shared_ptr {
|
||||||
|
// typecast so we can use this to get probability value
|
||||||
|
DiscreteValues values(choices);
|
||||||
|
|
||||||
|
if ((*probDecisionTree)(values) == 0.0) {
|
||||||
|
// empty aka null pointer
|
||||||
|
boost::shared_ptr<GaussianConditional> null;
|
||||||
|
return null;
|
||||||
|
} else {
|
||||||
|
return conditional;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return pruner;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||||
// Get the decision tree of only the discrete keys
|
// Get the decision tree of only the discrete keys
|
||||||
|
@ -57,6 +89,8 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||||
boost::make_shared<DecisionTreeFactor>(
|
boost::make_shared<DecisionTreeFactor>(
|
||||||
discreteConditionals->prune(maxNrLeaves));
|
discreteConditionals->prune(maxNrLeaves));
|
||||||
|
|
||||||
|
auto discreteFactorKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
|
||||||
|
|
||||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||||
* For each leaf, using the assignment we can check the discrete decision tree
|
* For each leaf, using the assignment we can check the discrete decision tree
|
||||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||||
|
@ -66,23 +100,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||||
|
|
||||||
HybridBayesNet prunedBayesNetFragment;
|
HybridBayesNet prunedBayesNetFragment;
|
||||||
|
|
||||||
// Functional which loops over all assignments and create a set of
|
|
||||||
// GaussianConditionals
|
|
||||||
auto pruner = [&](const Assignment<Key> &choices,
|
|
||||||
const GaussianConditional::shared_ptr &conditional)
|
|
||||||
-> GaussianConditional::shared_ptr {
|
|
||||||
// typecast so we can use this to get probability value
|
|
||||||
DiscreteValues values(choices);
|
|
||||||
|
|
||||||
if ((*discreteFactor)(values) == 0.0) {
|
|
||||||
// empty aka null pointer
|
|
||||||
boost::shared_ptr<GaussianConditional> null;
|
|
||||||
return null;
|
|
||||||
} else {
|
|
||||||
return conditional;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Go through all the conditionals in the
|
// Go through all the conditionals in the
|
||||||
// Bayes Net and prune them as per discreteFactor.
|
// Bayes Net and prune them as per discreteFactor.
|
||||||
for (size_t i = 0; i < this->size(); i++) {
|
for (size_t i = 0; i < this->size(); i++) {
|
||||||
|
@ -92,17 +109,19 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||||
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());
|
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());
|
||||||
|
|
||||||
if (gaussianMixture) {
|
if (gaussianMixture) {
|
||||||
// We may have mixtures with less discrete keys than discreteFactor so we
|
// We may have mixtures with less discrete keys than discreteFactor so
|
||||||
// skip those since the label assignment does not exist.
|
// we skip those since the label assignment does not exist.
|
||||||
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
|
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
|
||||||
auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
|
if (gmKeySet != discreteFactorKeySet) {
|
||||||
if (gmKeySet != dfKeySet) {
|
|
||||||
// Add the gaussianMixture which doesn't have to be pruned.
|
// Add the gaussianMixture which doesn't have to be pruned.
|
||||||
prunedBayesNetFragment.push_back(
|
prunedBayesNetFragment.push_back(
|
||||||
boost::make_shared<HybridConditional>(gaussianMixture));
|
boost::make_shared<HybridConditional>(gaussianMixture));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the pruner function.
|
||||||
|
auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet);
|
||||||
|
|
||||||
// Run the pruning to get a new, pruned tree
|
// Run the pruning to get a new, pruned tree
|
||||||
GaussianMixture::Conditionals prunedTree =
|
GaussianMixture::Conditionals prunedTree =
|
||||||
gaussianMixture->conditionals().apply(pruner);
|
gaussianMixture->conditionals().apply(pruner);
|
||||||
|
@ -173,7 +192,7 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
return gbn;
|
return gbn;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesNet::optimize() const {
|
HybridValues HybridBayesNet::optimize() const {
|
||||||
// Solve for the MPE
|
// Solve for the MPE
|
||||||
DiscreteBayesNet discrete_bn;
|
DiscreteBayesNet discrete_bn;
|
||||||
|
@ -190,7 +209,7 @@ HybridValues HybridBayesNet::optimize() const {
|
||||||
return HybridValues(mpe, gbn.optimize());
|
return HybridValues(mpe, gbn.optimize());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* ************************************************************************* */
|
||||||
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
||||||
GaussianBayesNet gbn = this->choose(assignment);
|
GaussianBayesNet gbn = this->choose(assignment);
|
||||||
return gbn.optimize();
|
return gbn.optimize();
|
||||||
|
|
Loading…
Reference in New Issue