fix bug in prunerFunc due to missing keys in assignment

release/4.3a0
Varun Agrawal 2022-12-21 20:23:25 +05:30
parent a27979e84b
commit 80000b7e1b
1 changed files with 40 additions and 12 deletions

View File

@ -47,19 +47,21 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
/**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
* @param prunedDecisionTree The prob. decision tree of only discrete keys.
* @param conditional Conditional to prune. Used to get full assignment.
* @return std::function<double(const Assignment<Key> &, double)>
*/
std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &decisionTree,
const DecisionTreeFactor &prunedDecisionTree,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the Gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
std::set<DiscreteKey> decisionTreeKeySet =
DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
std::set<DiscreteKey> conditionalKeySet =
DiscreteKeysAsSet(conditional.discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
const Assignment<Key> &choices,
double probability) -> double {
// typecast so we can use this to get probability value
@ -67,17 +69,44 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
// Case where the Gaussian mixture has the same
// discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0) {
if (prunedDecisionTree(values) == 0) {
return 0.0;
} else {
return probability;
}
} else {
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
// get a `values` which doesn't have the full set of keys.
std::set<Key> valuesKeys;
for (auto kvp : values) {
valuesKeys.insert(kvp.first);
}
std::set<Key> conditionalKeys;
for (auto kvp : conditionalKeySet) {
conditionalKeys.insert(kvp.first);
}
// If true, then values is missing some keys
if (conditionalKeys != valuesKeys) {
// Get the keys present in conditionalKeys but not in valuesKeys
std::vector<Key> missing_keys;
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
valuesKeys.begin(), valuesKeys.end(),
std::back_inserter(missing_keys));
// Insert missing keys with a default assignment.
for (auto missing_key : missing_keys) {
values[missing_key] = 0;
}
}
// Now we generate the full assignment by enumerating
// over all keys in the prunedDecisionTree.
// First we find the differing keys
std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
conditionalKeySet.begin(), conditionalKeySet.end(),
std::back_inserter(set_diff));
// Now enumerate over all assignments of the differing keys
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
@ -86,7 +115,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
// If any one of the sub-branches are non-zero,
// we need this probability.
if (decisionTree(augmented_values) > 0.0) {
if (prunedDecisionTree(augmented_values) > 0.0) {
return probability;
}
}
@ -107,10 +136,7 @@ void HybridBayesNet::updateDiscreteConditionals(
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// std::cout << demangle(typeid(conditional).name()) << std::endl;
auto discrete = conditional->asDiscreteConditional();
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
// Apply prunerFunc to the underlying AlgebraicDecisionTree
auto discreteTree =
@ -119,6 +145,8 @@ void HybridBayesNet::updateDiscreteConditionals(
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));
// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);