diff --git a/gtsam/inference/Conditional-inst.h b/gtsam/inference/Conditional-inst.h index 4aa9c5126..9377e8cc4 100644 --- a/gtsam/inference/Conditional-inst.h +++ b/gtsam/inference/Conditional-inst.h @@ -57,6 +57,14 @@ double Conditional::evaluate( throw std::runtime_error("Conditional::evaluate is not implemented"); } +/* ************************************************************************* */ +template +double Conditional::logNormalizationConstant() + const { + throw std::runtime_error( + "Conditional::logNormalizationConstant is not implemented"); +} + /* ************************************************************************* */ template double Conditional::normalizationConstant() const { @@ -75,8 +83,13 @@ bool Conditional::CheckInvariants( const double logProb = conditional.logProbability(values); if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9) return false; // logProb is not consistent with prob_or_density - const double expected = - conditional.logNormalizationConstant() - conditional.error(values); + if (std::abs(conditional.logNormalizationConstant() - + std::log(conditional.normalizationConstant())) > 1e-9) + return false; // log normalization constant is not consistent with + // normalization constant + const double error = conditional.error(values); + if (error < 0.0) return false; // prob_or_density is negative. + const double expected = conditional.logNormalizationConstant() - error; if (std::abs(logProb - expected) > 1e-9) return false; // logProb is not consistent with error return true; diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index 351d2d4a4..ba7b6897e 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -144,10 +144,10 @@ namespace gtsam { } /** - * By default, log normalization constant = 0.0. - * Override if this depends on the parameters. + * All conditional types need to implement a log normalization constant to + * make it such that error>=0. */ - virtual double logNormalizationConstant() const { return 0.0; } + virtual double logNormalizationConstant() const; /** Non-virtual, exponentiate logNormalizationConstant. */ double normalizationConstant() const; @@ -189,6 +189,8 @@ namespace gtsam { * - evaluate >= 0.0 * - evaluate(x) == conditional(x) * - exp(logProbability(x)) == evaluate(x) + * - logNormalizationConstant() = log(normalizationConstant()) + * - error >= 0.0 * - logProbability(x) == logNormalizationConstant() - error(x) * * @param conditional The conditional to test, as a reference to the derived type.