diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 3a654ddad..064905d6b 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -122,31 +122,87 @@ void GaussianMixture::print(const std::string &s, if (gf && !gf->empty()) { gf->print("", formatter); return rd.str(); + // return "Node()"; } else { return "nullptr"; } }); } -/* *******************************************************************************/ -void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { - // Functional which loops over all assignments and create a set of - // GaussianConditionals - auto pruner = [&decisionTree]( - const Assignment &choices, +/* ************************************************************************* */ +/// Return the DiscreteKey vector as a set. +static std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { + std::set s; + s.insert(dkeys.begin(), dkeys.end()); + return s; +} + +/* ************************************************************************* */ +/** + * @brief Helper function to get the pruner functional. + * + * @param decisionTree The probability decision tree of only discrete keys. + * @param decisionTreeKeySet Set of DiscreteKeys in decisionTree. + * Pre-computed for efficiency. + * @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture. + * @return std::function &, const GaussianConditional::shared_ptr &)> + */ +std::function &, const GaussianConditional::shared_ptr &)> +PrunerFunc(const DecisionTreeFactor &decisionTree, + const std::set &decisionTreeKeySet, + const std::set &gaussianMixtureKeySet) { + auto pruner = [&](const Assignment &choices, const GaussianConditional::shared_ptr &conditional) -> GaussianConditional::shared_ptr { // typecast so we can use this to get probability value DiscreteValues values(choices); - if (decisionTree(values) == 0.0) { - // empty aka null pointer - boost::shared_ptr null; - return null; + // Case where the gaussian mixture has the same + // discrete keys as the decision tree. + if (gaussianMixtureKeySet == decisionTreeKeySet) { + if (decisionTree(values) == 0.0) { + // empty aka null pointer + boost::shared_ptr null; + return null; + } else { + return conditional; + } } else { - return conditional; + std::vector set_diff; + std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), + gaussianMixtureKeySet.begin(), + gaussianMixtureKeySet.end(), + std::back_inserter(set_diff)); + + const std::vector assignments = + DiscreteValues::CartesianProduct(set_diff); + for (const DiscreteValues &assignment : assignments) { + DiscreteValues augmented_values(values); + augmented_values.insert(assignment.begin(), assignment.end()); + + // If any one of the sub-branches are non-zero, + // we need this conditional. + if (decisionTree(augmented_values) > 0.0) { + return conditional; + } + } + // If we are here, it means that all the sub-branches are 0, + // so we prune. + return nullptr; } }; + return pruner; +} + +/* *******************************************************************************/ +void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { + auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); + auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys()); + // Functional which loops over all assignments and create a set of + // GaussianConditionals + auto pruner = PrunerFunc(decisionTree, decisionTreeKeySet, gmKeySet); auto pruned_conditionals = conditionals_.apply(pruner); conditionals_.root_ = pruned_conditionals.root_;