diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 3c77e3f9a..703c657cf 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -37,94 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { return Base::equals(bn, tol); } -/* ************************************************************************* */ -/** - * @brief Helper function to get the pruner functional. - * - * @param prunedDiscreteProbs The prob. decision tree of only discrete keys. - * @param conditional Conditional to prune. Used to get full assignment. - * @return std::function &, double)> - */ -std::function &, double)> prunerFunc( - const DecisionTreeFactor &prunedDiscreteProbs, - const HybridConditional &conditional) { - // Get the discrete keys as sets for the decision tree - // and the hybrid Gaussian conditional. - std::set discreteProbsKeySet = - DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys()); - std::set conditionalKeySet = - DiscreteKeysAsSet(conditional.discreteKeys()); - - auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet]( - const Assignment &choices, - double probability) -> double { - // This corresponds to 0 probability - double pruned_prob = 0.0; - - // typecast so we can use this to get probability value - DiscreteValues values(choices); - // Case where the hybrid Gaussian conditional has the same - // discrete keys as the decision tree. - if (conditionalKeySet == discreteProbsKeySet) { - if (prunedDiscreteProbs(values) == 0) { - return pruned_prob; - } else { - return probability; - } - } else { - // Due to branch merging (aka pruning) in DecisionTree, it is possible we - // get a `values` which doesn't have the full set of keys. - std::set valuesKeys; - for (auto kvp : values) { - valuesKeys.insert(kvp.first); - } - std::set conditionalKeys; - for (auto kvp : conditionalKeySet) { - conditionalKeys.insert(kvp.first); - } - // If true, then values is missing some keys - if (conditionalKeys != valuesKeys) { - // Get the keys present in conditionalKeys but not in valuesKeys - std::vector missing_keys; - std::set_difference(conditionalKeys.begin(), conditionalKeys.end(), - valuesKeys.begin(), valuesKeys.end(), - std::back_inserter(missing_keys)); - // Insert missing keys with a default assignment. - for (auto missing_key : missing_keys) { - values[missing_key] = 0; - } - } - - // Now we generate the full assignment by enumerating - // over all keys in the prunedDiscreteProbs. - // First we find the differing keys - std::vector set_diff; - std::set_difference(discreteProbsKeySet.begin(), - discreteProbsKeySet.end(), conditionalKeySet.begin(), - conditionalKeySet.end(), - std::back_inserter(set_diff)); - - // Now enumerate over all assignments of the differing keys - const std::vector assignments = - DiscreteValues::CartesianProduct(set_diff); - for (const DiscreteValues &assignment : assignments) { - DiscreteValues augmented_values(values); - augmented_values.insert(assignment); - - // If any one of the sub-branches are non-zero, - // we need this probability. - if (prunedDiscreteProbs(augmented_values) > 0.0) { - return probability; - } - } - // If we are here, it means that all the sub-branches are 0, - // so we prune. - return pruned_prob; - } - }; - return pruner; -} - /* ************************************************************************* */ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( size_t maxNrLeaves) { @@ -164,9 +76,10 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( } /* ************************************************************************* */ -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { + HybridBayesNet copy(*this); DecisionTreeFactor prunedDiscreteProbs = - this->pruneDiscreteConditionals(maxNrLeaves); + copy.pruneDiscreteConditionals(maxNrLeaves); /* To prune, we visitWith every leaf in the HybridGaussianConditional. * For each leaf, using the assignment we can check the discrete decision tree @@ -179,13 +92,10 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { // Go through all the conditionals in the // Bayes Net and prune them as per prunedDiscreteProbs. - for (auto &&conditional : *this) { + for (auto &&conditional : copy) { if (auto gm = conditional->asHybrid()) { // Make a copy of the hybrid Gaussian conditional and prune it! - auto prunedHybridGaussianConditional = - std::make_shared(*gm); - prunedHybridGaussianConditional->prune( - prunedDiscreteProbs); // imperative :-( + auto prunedHybridGaussianConditional = gm->prune(prunedDiscreteProbs); // Type-erase and add to the pruned Bayes Net fragment. prunedBayesNetFragment.push_back(prunedHybridGaussianConditional); @@ -336,10 +246,14 @@ AlgebraicDecisionTree HybridBayesNet::logProbability( }); } else if (auto dc = conditional->asDiscrete()) { // If discrete, add the discrete logProbability in the right branch - result = result.apply( - [dc](const Assignment &assignment, double leaf_value) { - return leaf_value + dc->logProbability(DiscreteValues(assignment)); - }); + if (result.nrLeaves() == 1) { + result = dc->errorTree().apply([](double error) { return -error; }); + } else { + result = result.apply([dc](const Assignment &assignment, + double leaf_value) { + return leaf_value + dc->logProbability(DiscreteValues(assignment)); + }); + } } } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 62688e8b2..9052a7a16 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -201,8 +201,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ HybridValues sample() const; - /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. - HybridBayesNet prune(size_t maxNrLeaves); + /** + * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. + * + * @param maxNrLeaves Continuous values at which to compute the error. + * @return A pruned HybridBayesNet + */ + HybridBayesNet prune(size_t maxNrLeaves) const; /** * @brief Compute conditional error for each discrete assignment,