Check for error>0 and proper normalization constant

release/4.3a0
Frank Dellaert 2023-01-16 15:32:34 -08:00
parent c22b2cad3b
commit 5b0408c7bb
2 changed files with 20 additions and 5 deletions

View File

@ -57,6 +57,14 @@ double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
throw std::runtime_error("Conditional::evaluate is not implemented"); throw std::runtime_error("Conditional::evaluate is not implemented");
} }
/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
double Conditional<FACTOR, DERIVEDCONDITIONAL>::logNormalizationConstant()
const {
throw std::runtime_error(
"Conditional::logNormalizationConstant is not implemented");
}
/* ************************************************************************* */ /* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL> template <class FACTOR, class DERIVEDCONDITIONAL>
double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const { double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const {
@ -75,8 +83,13 @@ bool Conditional<FACTOR, DERIVEDCONDITIONAL>::CheckInvariants(
const double logProb = conditional.logProbability(values); const double logProb = conditional.logProbability(values);
if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9) if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9)
return false; // logProb is not consistent with prob_or_density return false; // logProb is not consistent with prob_or_density
const double expected = if (std::abs(conditional.logNormalizationConstant() -
conditional.logNormalizationConstant() - conditional.error(values); std::log(conditional.normalizationConstant())) > 1e-9)
return false; // log normalization constant is not consistent with
// normalization constant
const double error = conditional.error(values);
if (error < 0.0) return false; // prob_or_density is negative.
const double expected = conditional.logNormalizationConstant() - error;
if (std::abs(logProb - expected) > 1e-9) if (std::abs(logProb - expected) > 1e-9)
return false; // logProb is not consistent with error return false; // logProb is not consistent with error
return true; return true;

View File

@ -144,10 +144,10 @@ namespace gtsam {
} }
/** /**
* By default, log normalization constant = 0.0. * All conditional types need to implement a log normalization constant to
* Override if this depends on the parameters. * make it such that error>=0.
*/ */
virtual double logNormalizationConstant() const { return 0.0; } virtual double logNormalizationConstant() const;
/** Non-virtual, exponentiate logNormalizationConstant. */ /** Non-virtual, exponentiate logNormalizationConstant. */
double normalizationConstant() const; double normalizationConstant() const;
@ -189,6 +189,8 @@ namespace gtsam {
* - evaluate >= 0.0 * - evaluate >= 0.0
* - evaluate(x) == conditional(x) * - evaluate(x) == conditional(x)
* - exp(logProbability(x)) == evaluate(x) * - exp(logProbability(x)) == evaluate(x)
* - logNormalizationConstant() = log(normalizationConstant())
* - error >= 0.0
* - logProbability(x) == logNormalizationConstant() - error(x) * - logProbability(x) == logNormalizationConstant() - error(x)
* *
* @param conditional The conditional to test, as a reference to the derived type. * @param conditional The conditional to test, as a reference to the derived type.