diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 163e77e47..96a6dfd63 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -70,12 +70,37 @@ PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree, // typecast so we can use this to get probability value DiscreteValues values(choices); - if ((*probDecisionTree)(values) == 0.0) { - // empty aka null pointer - boost::shared_ptr null; - return null; + // Case where the gaussian mixture has the same + // discrete keys as the decision tree. + if (gaussianMixtureKeySet == discreteFactorKeySet) { + if ((*probDecisionTree)(values) == 0.0) { + // empty aka null pointer + boost::shared_ptr null; + return null; + } else { + return conditional; + } } else { - return conditional; + std::vector set_diff; + std::set_difference( + discreteFactorKeySet.begin(), discreteFactorKeySet.end(), + gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(), + std::back_inserter(set_diff)); + const std::vector assignments = + DiscreteValues::CartesianProduct(set_diff); + for (const DiscreteValues &assignment : assignments) { + DiscreteValues augmented_values(values); + augmented_values.insert(assignment.begin(), assignment.end()); + + // If any one of the sub-branches are non-zero, + // we need this conditional. + if ((*probDecisionTree)(augmented_values) > 0.0) { + return conditional; + } + } + // If we are here, it means that all the sub-branches are 0, + // so we prune. + return nullptr; } }; return pruner; @@ -112,12 +137,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { // We may have mixtures with less discrete keys than discreteFactor so // we skip those since the label assignment does not exist. auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys()); - if (gmKeySet != discreteFactorKeySet) { - // Add the gaussianMixture which doesn't have to be pruned. - prunedBayesNetFragment.push_back( - boost::make_shared(gaussianMixture)); - continue; - } // Get the pruner function. auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet);