diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 784b11e51..7ed116016 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -131,7 +131,7 @@ namespace gtsam { /// Calculate probability for given values `x`, /// is just look up in AlgebraicDecisionTree. - double evaluate(const DiscreteValues& values) const { + double evaluate(const Assignment& values) const { return ADT::operator()(values); } @@ -155,7 +155,7 @@ namespace gtsam { return apply(f, safe_div); } - /// Convert into a decisiontree + /// Convert into a decision tree DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } /// Create new factor by summing all values with the same separator values diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 1c3a69ce7..2c0fb28a4 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -291,45 +291,24 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { /* *******************************************************************************/ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( const DecisionTreeFactor &discreteProbs) const { - auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); - auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys()); + // Find keys in discreteProbs.keys() but not in this->keys(): + std::set mine(this->keys().begin(), this->keys().end()); + std::set theirs(discreteProbs.keys().begin(), + discreteProbs.keys().end()); + std::vector diff; + std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), + std::back_inserter(diff)); - // Functional which loops over all assignments and create a set of - // GaussianConditionals + // Find maximum probability value for every combination of our keys. + Ordering keys(diff); + auto max = discreteProbs.max(keys); + + // Check the max value for every combination of our keys. + // If the max value is 0.0, we can prune the corresponding conditional. auto pruner = [&](const Assignment &choices, const GaussianConditional::shared_ptr &conditional) -> GaussianConditional::shared_ptr { - // typecast so we can use this to get probability value - const DiscreteValues values(choices); - - // Case where the hybrid gaussian conditional has the same - // discrete keys as the decision tree. - if (hybridGaussianCondKeySet == discreteProbsKeySet) { - return (discreteProbs(values) == 0.0) ? nullptr : conditional; - } else { - // TODO(Frank): It might be faster to "choose" based on values - // and then check whether the resulting tree has non-nullptrs. - std::vector set_diff; - std::set_difference( - discreteProbsKeySet.begin(), discreteProbsKeySet.end(), - hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(), - std::back_inserter(set_diff)); - - const std::vector assignments = - DiscreteValues::CartesianProduct(set_diff); - for (const DiscreteValues &assignment : assignments) { - DiscreteValues augmented_values(values); - augmented_values.insert(assignment); - - // If any one of the sub-branches are non-zero, - // we need this conditional. - if (discreteProbs(augmented_values) > 0.0) { - return conditional; - } - } - // If we are here, it means that all the sub-branches are 0, so we prune. - return nullptr; - } + return (max->evaluate(choices) == 0.0) ? nullptr : conditional; }; auto pruned_conditionals = conditionals_.apply(pruner);