fix bug in prunerFunc due to missing keys in assignment
parent
a27979e84b
commit
80000b7e1b
|
|
@ -47,19 +47,21 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
|||
/**
|
||||
* @brief Helper function to get the pruner functional.
|
||||
*
|
||||
* @param decisionTree The probability decision tree of only discrete keys.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
* @param prunedDecisionTree The prob. decision tree of only discrete keys.
|
||||
* @param conditional Conditional to prune. Used to get full assignment.
|
||||
* @return std::function<double(const Assignment<Key> &, double)>
|
||||
*/
|
||||
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||
const DecisionTreeFactor &decisionTree,
|
||||
const DecisionTreeFactor &prunedDecisionTree,
|
||||
const HybridConditional &conditional) {
|
||||
// Get the discrete keys as sets for the decision tree
|
||||
// and the Gaussian mixture.
|
||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
||||
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
|
||||
std::set<DiscreteKey> decisionTreeKeySet =
|
||||
DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
|
||||
std::set<DiscreteKey> conditionalKeySet =
|
||||
DiscreteKeysAsSet(conditional.discreteKeys());
|
||||
|
||||
auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
|
||||
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
|
||||
const Assignment<Key> &choices,
|
||||
double probability) -> double {
|
||||
// typecast so we can use this to get probability value
|
||||
|
|
@ -67,17 +69,44 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
// Case where the Gaussian mixture has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (conditionalKeySet == decisionTreeKeySet) {
|
||||
if (decisionTree(values) == 0) {
|
||||
if (prunedDecisionTree(values) == 0) {
|
||||
return 0.0;
|
||||
} else {
|
||||
return probability;
|
||||
}
|
||||
} else {
|
||||
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
|
||||
// get a `values` which doesn't have the full set of keys.
|
||||
std::set<Key> valuesKeys;
|
||||
for (auto kvp : values) {
|
||||
valuesKeys.insert(kvp.first);
|
||||
}
|
||||
std::set<Key> conditionalKeys;
|
||||
for (auto kvp : conditionalKeySet) {
|
||||
conditionalKeys.insert(kvp.first);
|
||||
}
|
||||
// If true, then values is missing some keys
|
||||
if (conditionalKeys != valuesKeys) {
|
||||
// Get the keys present in conditionalKeys but not in valuesKeys
|
||||
std::vector<Key> missing_keys;
|
||||
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
|
||||
valuesKeys.begin(), valuesKeys.end(),
|
||||
std::back_inserter(missing_keys));
|
||||
// Insert missing keys with a default assignment.
|
||||
for (auto missing_key : missing_keys) {
|
||||
values[missing_key] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Now we generate the full assignment by enumerating
|
||||
// over all keys in the prunedDecisionTree.
|
||||
// First we find the differing keys
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
||||
conditionalKeySet.begin(), conditionalKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
|
||||
// Now enumerate over all assignments of the differing keys
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(set_diff);
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
|
|
@ -86,7 +115,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this probability.
|
||||
if (decisionTree(augmented_values) > 0.0) {
|
||||
if (prunedDecisionTree(augmented_values) > 0.0) {
|
||||
return probability;
|
||||
}
|
||||
}
|
||||
|
|
@ -107,10 +136,7 @@ void HybridBayesNet::updateDiscreteConditionals(
|
|||
for (size_t i = 0; i < this->size(); i++) {
|
||||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
// std::cout << demangle(typeid(conditional).name()) << std::endl;
|
||||
auto discrete = conditional->asDiscreteConditional();
|
||||
KeyVector frontals(discrete->frontals().begin(),
|
||||
discrete->frontals().end());
|
||||
|
||||
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
||||
auto discreteTree =
|
||||
|
|
@ -119,6 +145,8 @@ void HybridBayesNet::updateDiscreteConditionals(
|
|||
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
|
||||
|
||||
// Create the new (hybrid) conditional
|
||||
KeyVector frontals(discrete->frontals().begin(),
|
||||
discrete->frontals().end());
|
||||
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
|
||||
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
||||
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
|
||||
|
|
|
|||
Loading…
Reference in New Issue