diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index b764dc9e0..bf11a50fc 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -279,21 +279,37 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) { using Result = std::pair, GaussianMixtureFactor::sharedFactor>; -// Integrate the probability mass in the last continuous conditional using -// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean. -// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m)) +/** + * Compute the probability q(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m) + * from the residual error at the mean μ. + * The residual error contains no keys, and only + * depends on the discrete separator if present. + */ static std::shared_ptr createDiscreteFactor( const DecisionTree &eliminationResults, const DiscreteKeys &discreteSeparator) { - auto probability = [&](const Result &pair) -> double { + auto logProbability = [&](const Result &pair) -> double { const auto &[conditional, factor] = pair; static const VectorValues kEmpty; // If the factor is not null, it has no keys, just contains the residual. if (!factor) return 1.0; // TODO(dellaert): not loving this. - return exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); + + // Logspace version of: + // exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); + // We take negative of the logNormalizationConstant `log(1/k)` + // to get `log(k)`. + return -factor->error(kEmpty) + (-conditional->logNormalizationConstant()); }; - DecisionTree probabilities(eliminationResults, probability); + AlgebraicDecisionTree logProbabilities( + DecisionTree(eliminationResults, logProbability)); + + // Perform normalization + double max_log = logProbabilities.max(); + AlgebraicDecisionTree probabilities = DecisionTree( + logProbabilities, + [&max_log](const double x) { return exp(x - max_log); }); + probabilities = probabilities.normalize(probabilities.sum()); return std::make_shared(discreteSeparator, probabilities); }