add PrunerFunc to GaussianMixture
parent
2c8fe25842
commit
c15cfb6068
|
|
@ -122,31 +122,87 @@ void GaussianMixture::print(const std::string &s,
|
||||||
if (gf && !gf->empty()) {
|
if (gf && !gf->empty()) {
|
||||||
gf->print("", formatter);
|
gf->print("", formatter);
|
||||||
return rd.str();
|
return rd.str();
|
||||||
|
// return "Node()";
|
||||||
} else {
|
} else {
|
||||||
return "nullptr";
|
return "nullptr";
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* ************************************************************************* */
|
||||||
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
/// Return the DiscreteKey vector as a set.
|
||||||
// Functional which loops over all assignments and create a set of
|
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||||
// GaussianConditionals
|
std::set<DiscreteKey> s;
|
||||||
auto pruner = [&decisionTree](
|
s.insert(dkeys.begin(), dkeys.end());
|
||||||
const Assignment<Key> &choices,
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
/**
|
||||||
|
* @brief Helper function to get the pruner functional.
|
||||||
|
*
|
||||||
|
* @param decisionTree The probability decision tree of only discrete keys.
|
||||||
|
* @param decisionTreeKeySet Set of DiscreteKeys in decisionTree.
|
||||||
|
* Pre-computed for efficiency.
|
||||||
|
* @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture.
|
||||||
|
* @return std::function<GaussianConditional::shared_ptr(
|
||||||
|
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||||
|
*/
|
||||||
|
std::function<GaussianConditional::shared_ptr(
|
||||||
|
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||||
|
PrunerFunc(const DecisionTreeFactor &decisionTree,
|
||||||
|
const std::set<DiscreteKey> &decisionTreeKeySet,
|
||||||
|
const std::set<DiscreteKey> &gaussianMixtureKeySet) {
|
||||||
|
auto pruner = [&](const Assignment<Key> &choices,
|
||||||
const GaussianConditional::shared_ptr &conditional)
|
const GaussianConditional::shared_ptr &conditional)
|
||||||
-> GaussianConditional::shared_ptr {
|
-> GaussianConditional::shared_ptr {
|
||||||
// typecast so we can use this to get probability value
|
// typecast so we can use this to get probability value
|
||||||
DiscreteValues values(choices);
|
DiscreteValues values(choices);
|
||||||
|
|
||||||
if (decisionTree(values) == 0.0) {
|
// Case where the gaussian mixture has the same
|
||||||
// empty aka null pointer
|
// discrete keys as the decision tree.
|
||||||
boost::shared_ptr<GaussianConditional> null;
|
if (gaussianMixtureKeySet == decisionTreeKeySet) {
|
||||||
return null;
|
if (decisionTree(values) == 0.0) {
|
||||||
|
// empty aka null pointer
|
||||||
|
boost::shared_ptr<GaussianConditional> null;
|
||||||
|
return null;
|
||||||
|
} else {
|
||||||
|
return conditional;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return conditional;
|
std::vector<DiscreteKey> set_diff;
|
||||||
|
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
||||||
|
gaussianMixtureKeySet.begin(),
|
||||||
|
gaussianMixtureKeySet.end(),
|
||||||
|
std::back_inserter(set_diff));
|
||||||
|
|
||||||
|
const std::vector<DiscreteValues> assignments =
|
||||||
|
DiscreteValues::CartesianProduct(set_diff);
|
||||||
|
for (const DiscreteValues &assignment : assignments) {
|
||||||
|
DiscreteValues augmented_values(values);
|
||||||
|
augmented_values.insert(assignment.begin(), assignment.end());
|
||||||
|
|
||||||
|
// If any one of the sub-branches are non-zero,
|
||||||
|
// we need this conditional.
|
||||||
|
if (decisionTree(augmented_values) > 0.0) {
|
||||||
|
return conditional;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If we are here, it means that all the sub-branches are 0,
|
||||||
|
// so we prune.
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
return pruner;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
||||||
|
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
||||||
|
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||||
|
// Functional which loops over all assignments and create a set of
|
||||||
|
// GaussianConditionals
|
||||||
|
auto pruner = PrunerFunc(decisionTree, decisionTreeKeySet, gmKeySet);
|
||||||
|
|
||||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||||
conditionals_.root_ = pruned_conditionals.root_;
|
conditionals_.root_ = pruned_conditionals.root_;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue