diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index cb8ceed20..bbf739c65 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -233,6 +233,26 @@ continuousElimination(const HybridGaussianFactorGraph &factors, return {std::make_shared(result.first), result.second}; } +/* ************************************************************************ */ +/** + * @brief Exponential log-probabilities after performing + * the necessary normalizations. + * + * @param logProbabilities DecisionTree of log-probabilities. + * @return AlgebraicDecisionTree + */ +static AlgebraicDecisionTree exponentiateLogProbabilities( + const AlgebraicDecisionTree &logProbabilities) { + // Perform normalization + double max_log = logProbabilities.max(); + AlgebraicDecisionTree probabilities = DecisionTree( + logProbabilities, + [&max_log](const double x) { return exp(x - max_log); }); + probabilities = probabilities.normalize(probabilities.sum()); + + return probabilities; +} + /* ************************************************************************ */ static std::pair> discreteElimination(const HybridGaussianFactorGraph &factors, @@ -245,14 +265,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } else if (auto gmf = dynamic_pointer_cast(f)) { // Case where we have a GaussianMixtureFactor with no continuous keys. // In this case, compute discrete probabilities. - auto probability = + auto logProbability = [&](const GaussianFactor::shared_ptr &factor) -> double { if (!factor) return 0.0; - return exp(-factor->error(VectorValues())); + return -factor->error(VectorValues()); }; - dfg.emplace_shared( - gmf->discreteKeys(), - DecisionTree(gmf->factors(), probability)); + AlgebraicDecisionTree logProbabilities = + DecisionTree(gmf->factors(), logProbability); + + AlgebraicDecisionTree probabilities = + exponentiateLogProbabilities(logProbabilities); + dfg.emplace_shared(gmf->discreteKeys(), + probabilities); } else if (auto orphan = dynamic_pointer_cast(f)) { // Ignore orphaned clique. @@ -315,13 +339,8 @@ static std::shared_ptr createDiscreteFactor( AlgebraicDecisionTree logProbabilities( DecisionTree(eliminationResults, logProbability)); - - // Perform normalization - double max_log = logProbabilities.max(); - AlgebraicDecisionTree probabilities = DecisionTree( - logProbabilities, - [&max_log](const double x) { return exp(x - max_log); }); - probabilities = probabilities.normalize(probabilities.sum()); + AlgebraicDecisionTree probabilities = + exponentiateLogProbabilities(logProbabilities); return std::make_shared(discreteSeparator, probabilities); }