Check for error>0 and proper normalization constant
parent
c22b2cad3b
commit
5b0408c7bb
|
@ -57,6 +57,14 @@ double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
|
|||
throw std::runtime_error("Conditional::evaluate is not implemented");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||
double Conditional<FACTOR, DERIVEDCONDITIONAL>::logNormalizationConstant()
|
||||
const {
|
||||
throw std::runtime_error(
|
||||
"Conditional::logNormalizationConstant is not implemented");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||
double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const {
|
||||
|
@ -75,8 +83,13 @@ bool Conditional<FACTOR, DERIVEDCONDITIONAL>::CheckInvariants(
|
|||
const double logProb = conditional.logProbability(values);
|
||||
if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9)
|
||||
return false; // logProb is not consistent with prob_or_density
|
||||
const double expected =
|
||||
conditional.logNormalizationConstant() - conditional.error(values);
|
||||
if (std::abs(conditional.logNormalizationConstant() -
|
||||
std::log(conditional.normalizationConstant())) > 1e-9)
|
||||
return false; // log normalization constant is not consistent with
|
||||
// normalization constant
|
||||
const double error = conditional.error(values);
|
||||
if (error < 0.0) return false; // prob_or_density is negative.
|
||||
const double expected = conditional.logNormalizationConstant() - error;
|
||||
if (std::abs(logProb - expected) > 1e-9)
|
||||
return false; // logProb is not consistent with error
|
||||
return true;
|
||||
|
|
|
@ -144,10 +144,10 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/**
|
||||
* By default, log normalization constant = 0.0.
|
||||
* Override if this depends on the parameters.
|
||||
* All conditional types need to implement a log normalization constant to
|
||||
* make it such that error>=0.
|
||||
*/
|
||||
virtual double logNormalizationConstant() const { return 0.0; }
|
||||
virtual double logNormalizationConstant() const;
|
||||
|
||||
/** Non-virtual, exponentiate logNormalizationConstant. */
|
||||
double normalizationConstant() const;
|
||||
|
@ -189,6 +189,8 @@ namespace gtsam {
|
|||
* - evaluate >= 0.0
|
||||
* - evaluate(x) == conditional(x)
|
||||
* - exp(logProbability(x)) == evaluate(x)
|
||||
* - logNormalizationConstant() = log(normalizationConstant())
|
||||
* - error >= 0.0
|
||||
* - logProbability(x) == logNormalizationConstant() - error(x)
|
||||
*
|
||||
* @param conditional The conditional to test, as a reference to the derived type.
|
||||
|
|
Loading…
Reference in New Issue