diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 3c5130f42..478f94f18 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -288,59 +288,6 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { return s; } -/* ************************************************************************* */ -std::function &, const GaussianConditional::shared_ptr &)> -HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) { - // Get the discrete keys as sets for the decision tree - // and the hybrid gaussian conditional. - auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); - auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys()); - - auto pruner = [discreteProbs, discreteProbsKeySet, hybridGaussianCondKeySet]( - const Assignment &choices, - const GaussianConditional::shared_ptr &conditional) - -> GaussianConditional::shared_ptr { - // typecast so we can use this to get probability value - const DiscreteValues values(choices); - - // Case where the hybrid gaussian conditional has the same - // discrete keys as the decision tree. - if (hybridGaussianCondKeySet == discreteProbsKeySet) { - if (discreteProbs(values) == 0.0) { - // empty aka null pointer - std::shared_ptr null; - return null; - } else { - return conditional; - } - } else { - std::vector set_diff; - std::set_difference( - discreteProbsKeySet.begin(), discreteProbsKeySet.end(), - hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(), - std::back_inserter(set_diff)); - - const std::vector assignments = - DiscreteValues::CartesianProduct(set_diff); - for (const DiscreteValues &assignment : assignments) { - DiscreteValues augmented_values(values); - augmented_values.insert(assignment); - - // If any one of the sub-branches are non-zero, - // we need this conditional. - if (discreteProbs(augmented_values) > 0.0) { - return conditional; - } - } - // If we are here, it means that all the sub-branches are 0, - // so we prune. - return nullptr; - } - }; - return pruner; -} - /* *******************************************************************************/ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( const DecisionTreeFactor &discreteProbs) const { @@ -358,14 +305,10 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( // Case where the hybrid gaussian conditional has the same // discrete keys as the decision tree. if (hybridGaussianCondKeySet == discreteProbsKeySet) { - if (discreteProbs(values) == 0.0) { - // empty aka null pointer - std::shared_ptr null; - return null; - } else { - return conditional; - } + return (discreteProbs(values) == 0.0) ? nullptr : conditional; } else { + // TODO(Frank): It might be faster to "choose" based on values + // and then check whether the resulting tree has non-nullptrs. std::vector set_diff; std::set_difference( discreteProbsKeySet.begin(), discreteProbsKeySet.end(), @@ -384,8 +327,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( return conditional; } } - // If we are here, it means that all the sub-branches are 0, - // so we prune. + // If we are here, it means that all the sub-branches are 0, so we prune. return nullptr; } }; diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index ede748b16..8f3aa6778 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -243,17 +243,6 @@ class GTSAM_EXPORT HybridGaussianConditional /// Convert to a DecisionTree of Gaussian factor graphs. GaussianFactorGraphTree asGaussianFactorGraphTree() const; - /** - * @brief Get the pruner function from discrete probabilities. - * - * @param discreteProbs The probabilities of only discrete keys. - * @return std::function &, const GaussianConditional::shared_ptr &)> - */ - std::function &, const GaussianConditional::shared_ptr &)> - prunerFunc(const DecisionTreeFactor &prunedProbabilities); - /// Check whether `given` has values for all frontal keys. bool allFrontalsGiven(const VectorValues &given) const;