Check for error>0 and proper normalization constant

release/4.3a0
Frank Dellaert 2023-01-16 15:32:34 -08:00
parent c22b2cad3b
commit 5b0408c7bb
2 changed files with 20 additions and 5 deletions

View File

@ -57,6 +57,14 @@ double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
throw std::runtime_error("Conditional::evaluate is not implemented");
}
/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
double Conditional<FACTOR, DERIVEDCONDITIONAL>::logNormalizationConstant()
const {
throw std::runtime_error(
"Conditional::logNormalizationConstant is not implemented");
}
/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const {
@ -75,8 +83,13 @@ bool Conditional<FACTOR, DERIVEDCONDITIONAL>::CheckInvariants(
const double logProb = conditional.logProbability(values);
if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9)
return false; // logProb is not consistent with prob_or_density
const double expected =
conditional.logNormalizationConstant() - conditional.error(values);
if (std::abs(conditional.logNormalizationConstant() -
std::log(conditional.normalizationConstant())) > 1e-9)
return false; // log normalization constant is not consistent with
// normalization constant
const double error = conditional.error(values);
if (error < 0.0) return false; // prob_or_density is negative.
const double expected = conditional.logNormalizationConstant() - error;
if (std::abs(logProb - expected) > 1e-9)
return false; // logProb is not consistent with error
return true;

View File

@ -144,10 +144,10 @@ namespace gtsam {
}
/**
* By default, log normalization constant = 0.0.
* Override if this depends on the parameters.
* All conditional types need to implement a log normalization constant to
* make it such that error>=0.
*/
virtual double logNormalizationConstant() const { return 0.0; }
virtual double logNormalizationConstant() const;
/** Non-virtual, exponentiate logNormalizationConstant. */
double normalizationConstant() const;
@ -189,6 +189,8 @@ namespace gtsam {
* - evaluate >= 0.0
* - evaluate(x) == conditional(x)
* - exp(logProbability(x)) == evaluate(x)
* - logNormalizationConstant() = log(normalizationConstant())
* - error >= 0.0
* - logProbability(x) == logNormalizationConstant() - error(x)
*
* @param conditional The conditional to test, as a reference to the derived type.