diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f2cac485d..e48b3faf7 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -303,9 +303,10 @@ HybridValues HybridBayesNet::optimize() const { } double min_log = error.min(); - AlgebraicDecisionTree model_selection = DecisionTree( - error, [&min_log](const double &x) { return std::exp(-(x - min_log)); }); - model_selection = model_selection + exp(-min_log); + AlgebraicDecisionTree model_selection = + DecisionTree(error, [&min_log](const double &x) { + return std::exp(-(x - min_log)) * exp(-min_log); + }); // Only add model_selection if we have discrete keys if (discreteKeySet.size() > 0) { diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 467cff710..dfff3d4f3 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -323,18 +323,29 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // If there are no more continuous parents, then we create a // DiscreteFactor here, with the error for each discrete choice. - // Integrate the probability mass in the last continuous conditional using - // the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean. - // discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m)) - auto probability = [&](const Result &pair) -> double { + // Compute the probability q(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m)) + // from the residual error at the mean μ. + // The residual error contains no keys, and only depends on the discrete + // separator if present. + auto logProbability = [&](const Result &pair) -> double { + // auto probability = [&](const Result &pair) -> double { static const VectorValues kEmpty; // If the factor is not null, it has no keys, just contains the residual. const auto &factor = pair.second; if (!factor) return 1.0; // TODO(dellaert): not loving this. - return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant(); + + // exp(-factor->error(kEmpty)) / pair.first->normalizationConstant(); + return -factor->error(kEmpty) - pair.first->logNormalizationConstant(); }; - DecisionTree probabilities(eliminationResults, probability); + AlgebraicDecisionTree logProbabilities( + DecisionTree(eliminationResults, logProbability)); + + // Perform normalization + double max_log = logProbabilities.max(); + DecisionTree probabilities( + logProbabilities, + [&max_log](const double x) { return exp(x - max_log) * exp(max_log); }); return { std::make_shared(gaussianMixture),