diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f5c11c6e1..04636f74e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -47,19 +47,21 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { /** * @brief Helper function to get the pruner functional. * - * @param decisionTree The probability decision tree of only discrete keys. - * @return std::function &, const GaussianConditional::shared_ptr &)> + * @param prunedDecisionTree The prob. decision tree of only discrete keys. + * @param conditional Conditional to prune. Used to get full assignment. + * @return std::function &, double)> */ std::function &, double)> prunerFunc( - const DecisionTreeFactor &decisionTree, + const DecisionTreeFactor &prunedDecisionTree, const HybridConditional &conditional) { // Get the discrete keys as sets for the decision tree // and the Gaussian mixture. - auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); - auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys()); + std::set decisionTreeKeySet = + DiscreteKeysAsSet(prunedDecisionTree.discreteKeys()); + std::set conditionalKeySet = + DiscreteKeysAsSet(conditional.discreteKeys()); - auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet]( + auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet]( const Assignment &choices, double probability) -> double { // typecast so we can use this to get probability value @@ -67,17 +69,44 @@ std::function &, double)> prunerFunc( // Case where the Gaussian mixture has the same // discrete keys as the decision tree. if (conditionalKeySet == decisionTreeKeySet) { - if (decisionTree(values) == 0) { + if (prunedDecisionTree(values) == 0) { return 0.0; } else { return probability; } } else { + // Due to branch merging (aka pruning) in DecisionTree, it is possible we + // get a `values` which doesn't have the full set of keys. + std::set valuesKeys; + for (auto kvp : values) { + valuesKeys.insert(kvp.first); + } + std::set conditionalKeys; + for (auto kvp : conditionalKeySet) { + conditionalKeys.insert(kvp.first); + } + // If true, then values is missing some keys + if (conditionalKeys != valuesKeys) { + // Get the keys present in conditionalKeys but not in valuesKeys + std::vector missing_keys; + std::set_difference(conditionalKeys.begin(), conditionalKeys.end(), + valuesKeys.begin(), valuesKeys.end(), + std::back_inserter(missing_keys)); + // Insert missing keys with a default assignment. + for (auto missing_key : missing_keys) { + values[missing_key] = 0; + } + } + + // Now we generate the full assignment by enumerating + // over all keys in the prunedDecisionTree. + // First we find the differing keys std::vector set_diff; std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), conditionalKeySet.begin(), conditionalKeySet.end(), std::back_inserter(set_diff)); + // Now enumerate over all assignments of the differing keys const std::vector assignments = DiscreteValues::CartesianProduct(set_diff); for (const DiscreteValues &assignment : assignments) { @@ -86,7 +115,7 @@ std::function &, double)> prunerFunc( // If any one of the sub-branches are non-zero, // we need this probability. - if (decisionTree(augmented_values) > 0.0) { + if (prunedDecisionTree(augmented_values) > 0.0) { return probability; } } @@ -107,10 +136,7 @@ void HybridBayesNet::updateDiscreteConditionals( for (size_t i = 0; i < this->size(); i++) { HybridConditional::shared_ptr conditional = this->at(i); if (conditional->isDiscrete()) { - // std::cout << demangle(typeid(conditional).name()) << std::endl; auto discrete = conditional->asDiscreteConditional(); - KeyVector frontals(discrete->frontals().begin(), - discrete->frontals().end()); // Apply prunerFunc to the underlying AlgebraicDecisionTree auto discreteTree = @@ -119,6 +145,8 @@ void HybridBayesNet::updateDiscreteConditionals( discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional)); // Create the new (hybrid) conditional + KeyVector frontals(discrete->frontals().begin(), + discrete->frontals().end()); auto prunedDiscrete = boost::make_shared( frontals.size(), conditional->discreteKeys(), prunedDiscreteTree); conditional = boost::make_shared(prunedDiscrete);