handle numerical instability

release/4.3a0
Varun Agrawal 2023-12-18 14:25:19 -05:00
parent 6f09be51cb
commit 36604297d7
5 changed files with 24 additions and 21 deletions

View File

@@ -225,7 +225,8 @@ namespace gtsam {
   /// Find the maximum value amongst all leaves
   double max() const {
-    double max = std::numeric_limits<double>::min();
+    // Get the most negative value
+    double max = -std::numeric_limits<double>::max();
     auto visitor = [&](double x) { max = x > max ? x : max; };
     this->visit(visitor);
     return max;
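Why the seed change matters: std::numeric_limits<double>::min() is the smallest positive normalized double (~2.2e-308), not the most negative value, so the old seed made max() return that tiny positive number whenever every leaf was negative, as happens with log-probabilities. A minimal standalone sketch of the difference, using a plain vector of hypothetical leaf values rather than a DecisionTree:

#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

int main() {
  std::vector<double> leaves = {-5.0, -2.0, -9.0};    // all-negative leaves
  double bad = std::numeric_limits<double>::min();    // ~2.2e-308, positive!
  double good = -std::numeric_limits<double>::max();  // most negative double
  for (double x : leaves) {
    bad = std::max(bad, x);
    good = std::max(good, x);
  }
  std::cout << bad << "\n";   // 2.22507e-308: no leaf was ever larger
  std::cout << good << "\n";  // -2: the true maximum
}

Since C++11, std::numeric_limits<double>::lowest() expresses the same seed more directly than negating max().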

View File

@@ -274,26 +274,26 @@ HybridValues HybridBayesNet::optimize() const {
       q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
     If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
-    thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q/k))
+    thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
                                = exp(log(q) - log(k)) = exp(-error - log(k))
                                = exp(-(error + log(k)))
-    So let's compute (error + log(k)) and exponentiate later
+    So we compute (error + log(k)) and exponentiate later
   */
-  error = error + gm->error(continuousValues);
-  // Add the logNormalization constant to the error
+  // Add the error and the logNormalization constant
+  auto err = gm->error(continuousValues) + gm->logNormalizationConstant();
+  // Also compute the sum for discrete probability normalization
+  // (normalization trick for numerical stability)
   double sum = 0.0;
-  auto addConstant = [&gm, &sum](const double &error) {
-    double e = error + gm->logNormalizationConstant();
+  auto absSum = [&sum](const double &e) {
     sum += std::abs(e);
     return e;
   };
-  error = error.apply(addConstant);
-  // Normalize by the sum
-  error = error.normalize(sum);
+  err.visit(absSum);
+  // Normalize by the sum to prevent overflow
+  error = error + err.normalize(sum);
   // Include the discrete keys
   std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
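What the absSum/normalize(sum) pair does: every leaf of err holds error + log(k), which can be huge in magnitude, and exponentiating such values later would overflow or underflow. Dividing each leaf by the sum of absolute values shrinks the magnitudes while preserving their ordering, since the divisor is positive. A minimal sketch of the idea, with a plain std::vector and made-up values standing in for the AlgebraicDecisionTree:

#include <cmath>
#include <iostream>
#include <vector>

int main() {
  std::vector<double> logErrors = {900.0, 905.0, 910.0};  // exp(-910) underflows to 0
  double sum = 0.0;
  for (double e : logErrors) sum += std::abs(e);  // what the absSum visitor accumulates
  for (double& e : logErrors) e /= sum;           // what err.normalize(sum) returns
  for (double e : logErrors)
    std::cout << std::exp(-e) << " ";  // all finite, relative order preserved
}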
@@ -302,11 +302,11 @@ HybridValues HybridBayesNet::optimize() const {
       }
     }
-    double min_log = error.min();
-    AlgebraicDecisionTree<Key> model_selection =
-        DecisionTree<Key, double>(error, [&min_log](const double &x) {
-          return std::exp(-(x - min_log)) * exp(-min_log);
-        });
+    error = error * -1;
+    double max_log = error.max();
+    AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
+        error, [&max_log](const double &x) { return std::exp(x - max_log); });
     model_selection = model_selection.normalize(model_selection.sum());
     // Only add model_selection if we have discrete keys
     if (discreteKeySet.size() > 0) {
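The model-selection block applies the standard max-shift (log-sum-exp) trick: after flipping the sign, max_log is subtracted from every log-weight before exponentiation, so each exponent is <= 0, exp() stays in (0, 1], and normalize(sum) then turns the weights into probabilities. A self-contained sketch of the same computation over a plain vector (values are made up):

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Softmax over log-weights using the max-shift trick.
std::vector<double> softmax(std::vector<double> logw) {
  const double max_log = *std::max_element(logw.begin(), logw.end());
  double sum = 0.0;
  for (double& x : logw) sum += (x = std::exp(x - max_log));  // each in (0, 1]
  for (double& x : logw) x /= sum;  // normalize to probabilities
  return logw;
}

int main() {
  // A naive exp(800) would overflow to inf; the shifted version is exact.
  for (double p : softmax({800.0, 802.0, 799.0}))
    std::cout << p << " ";  // ~0.114 0.844 0.042
}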

View File

@@ -328,7 +328,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   // The residual error contains no keys, and only depends on the discrete
   // separator if present.
   auto logProbability = [&](const Result &pair) -> double {
-    // auto probability = [&](const Result &pair) -> double {
     static const VectorValues kEmpty;
     // If the factor is not null, it has no keys, just contains the residual.
     const auto &factor = pair.second;
@@ -343,9 +342,11 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   // Perform normalization
   double max_log = logProbabilities.max();
-  DecisionTree<Key, double> probabilities(
+  AlgebraicDecisionTree<Key> probabilities = DecisionTree<Key, double>(
       logProbabilities,
-      [&max_log](const double x) { return exp(x - max_log) * exp(max_log); });
+      [&max_log](const double x) { return exp(x - max_log); });
+  // probabilities.print("", DefaultKeyFormatter);
+  probabilities = probabilities.normalize(probabilities.sum());
   return {
       std::make_shared<HybridConditional>(gaussianMixture),
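The removed * exp(max_log) factor was the instability here: multiplying back by exp(max_log) reconstructs exp(x) exactly, which is infinite whenever max_log is large, defeating the shift entirely. Since only relative sizes matter for the discrete conditional, the scale is instead restored by normalize(probabilities.sum()). A small numeric check of the equivalence (made-up log-probabilities, not GTSAM code):

#include <cmath>
#include <iostream>

int main() {
  const double logp[2] = {710.0, 709.0};  // exp(710) already overflows to inf
  const double max_log = 710.0;
  // Old: exp(x - max_log) * exp(max_log) == exp(x) -> inf, then inf/inf = NaN.
  // New: keep the shifted values; the common factor exp(max_log) cancels
  // in the normalization, so the ratios are untouched.
  double shifted[2], sum = 0.0;
  for (int i = 0; i < 2; ++i) sum += (shifted[i] = std::exp(logp[i] - max_log));
  for (int i = 0; i < 2; ++i)
    std::cout << shifted[i] / sum << " ";  // ~0.731 0.269
}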

View File

@@ -333,7 +333,6 @@ TEST(HybridEstimation, Probability) {
   for (auto discrete_conditional : *discreteBayesNet) {
     bayesNet->add(discrete_conditional);
   }
-  auto discreteConditional = discreteBayesNet->at(0)->asDiscrete();
   HybridValues hybrid_values = bayesNet->optimize();

View File

@@ -184,8 +184,10 @@ namespace gtsam {
 double GaussianConditional::logNormalizationConstant() const {
   constexpr double log2pi = 1.8378770664093454835606594728112;
   size_t n = d().size();
-  // log det(Sigma) = -2.0 * logDeterminant()
-  return - 0.5 * n * log2pi + logDeterminant();
+  // Sigma = (R'R)^{-1}, det(Sigma) = det((R'R)^{-1}) = det(R'R)^{-1}
+  // log det(Sigma) = -log(det(R'R)) = -2*log(det(R))
+  // Hence, log det(Sigma) = -2.0 * logDeterminant()
+  return -0.5 * n * log2pi + logDeterminant();
}
/* ************************************************************************* */
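A quick 1-D sanity check of the algebra in the new comment (standalone, not GTSAM code): with Sigma = (R'R)^{-1} and R = 1/sigma, the formula -0.5*n*log(2*pi) + log(det(R)) should reproduce log(1/(sigma*sqrt(2*pi))), the log of the usual Gaussian normalization constant.

#include <cmath>
#include <iostream>

int main() {
  constexpr double log2pi = 1.8378770664093454835606594728112;
  const double sigma = 2.0;
  const double logDetR = std::log(1.0 / sigma);  // stands in for logDeterminant()
  const double viaFormula = -0.5 * 1 * log2pi + logDetR;  // n = 1
  const double direct = -std::log(sigma * std::sqrt(2.0 * std::acos(-1.0)));
  std::cout << viaFormula << " " << direct << "\n";  // both ~ -1.61209
}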