From 36604297d7db6f8a0c299f782a157c72f1a70cf9 Mon Sep 17 00:00:00 2001
From: Varun Agrawal <varagrawal@gmail.com>
Date: Mon, 18 Dec 2023 14:25:19 -0500
Subject: [PATCH] handle numerical instability

---
 gtsam/discrete/AlgebraicDecisionTree.h      |  3 ++-
 gtsam/hybrid/HybridBayesNet.cpp             | 28 ++++++++++-----------
 gtsam/hybrid/HybridGaussianFactorGraph.cpp  |  7 +++---
 gtsam/hybrid/tests/testHybridEstimation.cpp |  1 -
 gtsam/linear/GaussianConditional.cpp        |  6 +++--
 5 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h
index 1a6358680..17385a975 100644
--- a/gtsam/discrete/AlgebraicDecisionTree.h
+++ b/gtsam/discrete/AlgebraicDecisionTree.h
@@ -225,7 +225,8 @@ namespace gtsam {
     /// Find the maximum values amongst all leaves
     double max() const {
-      double max = std::numeric_limits<double>::min();
+      // Get the most negative value
+      double max = -std::numeric_limits<double>::max();
       auto visitor = [&](double x) { max = x > max ? x : max; };
       this->visit(visitor);
       return max;
     }
diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp
index e48b3faf7..2b0b11e36 100644
--- a/gtsam/hybrid/HybridBayesNet.cpp
+++ b/gtsam/hybrid/HybridBayesNet.cpp
@@ -274,26 +274,26 @@ HybridValues HybridBayesNet::optimize() const {
           q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
 
         If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
-        thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q/k))
+        thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
                                   = exp(log(q) - log(k)) = exp(-error - log(k))
                                   = exp(-(error + log(k)))
 
-        So let's compute (error + log(k)) and exponentiate later
+        So we compute (error + log(k)) and exponentiate later
       */
 
-      error = error + gm->error(continuousValues);
-      // Add the logNormalization constant to the error
+      // Add the error and the logNormalization constant to the error
+      auto err = gm->error(continuousValues) + gm->logNormalizationConstant();
+      // Also compute the sum for discrete probability normalization
       // (normalization trick for numerical stability)
       double sum = 0.0;
-      auto addConstant = [&gm, &sum](const double &error) {
-        double e = error + gm->logNormalizationConstant();
+      auto absSum = [&sum](const double &e) {
         sum += std::abs(e);
         return e;
       };
 
-      error = error.apply(addConstant);
-      // Normalize by the sum
-      error = error.normalize(sum);
+      err.visit(absSum);
+      // Normalize by the sum to prevent overflow
+      error = error + err.normalize(sum);
 
       // Include the discrete keys
       std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
@@ -302,11 +302,11 @@ HybridValues HybridBayesNet::optimize() const {
     }
   }
 
-  double min_log = error.min();
-  AlgebraicDecisionTree<Key> model_selection =
-      DecisionTree<Key, double>(error, [&min_log](const double &x) {
-        return std::exp(-(x - min_log)) * exp(-min_log);
-      });
+  error = error * -1;
+  double max_log = error.max();
+  AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
+      error, [&max_log](const double &x) { return std::exp(x - max_log); });
+  model_selection = model_selection.normalize(model_selection.sum());
 
   // Only add model_selection if we have discrete keys
   if (discreteKeySet.size() > 0) {
diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
index dfff3d4f3..260dc6bbe 100644
--- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp
+++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -328,7 +328,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   // The residual error contains no keys, and only depends on the discrete
   // separator if present.
   auto logProbability = [&](const Result &pair) -> double {
-    // auto probability = [&](const Result &pair) -> double {
     static const VectorValues kEmpty;
     // If the factor is not null, it has no keys, just contains the residual.
     const auto &factor = pair.second;
@@ -343,9 +342,11 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
 
   // Perform normalization
   double max_log = logProbabilities.max();
-  DecisionTree<Key, double> probabilities(
+  AlgebraicDecisionTree<Key> probabilities = DecisionTree<Key, double>(
       logProbabilities,
-      [&max_log](const double x) { return exp(x - max_log) * exp(max_log); });
+      [&max_log](const double x) { return exp(x - max_log); });
+  // probabilities.print("", DefaultKeyFormatter);
+  probabilities = probabilities.normalize(probabilities.sum());
 
   return {
       std::make_shared<HybridConditional>(gaussianMixture),
diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp
index b8edc39d8..1cc28b386 100644
--- a/gtsam/hybrid/tests/testHybridEstimation.cpp
+++ b/gtsam/hybrid/tests/testHybridEstimation.cpp
@@ -333,7 +333,6 @@ TEST(HybridEstimation, Probability) {
   for (auto discrete_conditional : *discreteBayesNet) {
    bayesNet->add(discrete_conditional);
   }
-  auto discreteConditional = discreteBayesNet->at(0)->asDiscrete();
 
   HybridValues hybrid_values = bayesNet->optimize();
 
diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp
index 0112835aa..4ec1d8b95 100644
--- a/gtsam/linear/GaussianConditional.cpp
+++ b/gtsam/linear/GaussianConditional.cpp
@@ -184,8 +184,10 @@ namespace gtsam {
  double GaussianConditional::logNormalizationConstant() const {
    constexpr double log2pi = 1.8378770664093454835606594728112;
    size_t n = d().size();
-    // log det(Sigma)) = - 2.0 * logDeterminant()
-    return - 0.5 * n * log2pi + logDeterminant();
+    // Sigma = (R'R)^{-1}, det(Sigma) = det((R'R)^{-1}) = det(R'R)^{-1}
+    // log det(Sigma) = -log(det(R'R)) = -2*log(det(R))
+    // Hence, log det(Sigma) = -2.0 * logDeterminant()
+    return -0.5 * n * log2pi + logDeterminant();
  }
 
 /* ************************************************************************* */