Check for error>0 and proper normalization constant
parent
c22b2cad3b
commit
5b0408c7bb
|
@ -57,6 +57,14 @@ double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
|
||||||
throw std::runtime_error("Conditional::evaluate is not implemented");
|
throw std::runtime_error("Conditional::evaluate is not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||||
|
double Conditional<FACTOR, DERIVEDCONDITIONAL>::logNormalizationConstant()
|
||||||
|
const {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Conditional::logNormalizationConstant is not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <class FACTOR, class DERIVEDCONDITIONAL>
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||||
double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const {
|
double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const {
|
||||||
|
@ -75,8 +83,13 @@ bool Conditional<FACTOR, DERIVEDCONDITIONAL>::CheckInvariants(
|
||||||
const double logProb = conditional.logProbability(values);
|
const double logProb = conditional.logProbability(values);
|
||||||
if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9)
|
if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9)
|
||||||
return false; // logProb is not consistent with prob_or_density
|
return false; // logProb is not consistent with prob_or_density
|
||||||
const double expected =
|
if (std::abs(conditional.logNormalizationConstant() -
|
||||||
conditional.logNormalizationConstant() - conditional.error(values);
|
std::log(conditional.normalizationConstant())) > 1e-9)
|
||||||
|
return false; // log normalization constant is not consistent with
|
||||||
|
// normalization constant
|
||||||
|
const double error = conditional.error(values);
|
||||||
|
if (error < 0.0) return false; // prob_or_density is negative.
|
||||||
|
const double expected = conditional.logNormalizationConstant() - error;
|
||||||
if (std::abs(logProb - expected) > 1e-9)
|
if (std::abs(logProb - expected) > 1e-9)
|
||||||
return false; // logProb is not consistent with error
|
return false; // logProb is not consistent with error
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -144,10 +144,10 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* By default, log normalization constant = 0.0.
|
* All conditional types need to implement a log normalization constant to
|
||||||
* Override if this depends on the parameters.
|
* make it such that error>=0.
|
||||||
*/
|
*/
|
||||||
virtual double logNormalizationConstant() const { return 0.0; }
|
virtual double logNormalizationConstant() const;
|
||||||
|
|
||||||
/** Non-virtual, exponentiate logNormalizationConstant. */
|
/** Non-virtual, exponentiate logNormalizationConstant. */
|
||||||
double normalizationConstant() const;
|
double normalizationConstant() const;
|
||||||
|
@ -189,6 +189,8 @@ namespace gtsam {
|
||||||
* - evaluate >= 0.0
|
* - evaluate >= 0.0
|
||||||
* - evaluate(x) == conditional(x)
|
* - evaluate(x) == conditional(x)
|
||||||
* - exp(logProbability(x)) == evaluate(x)
|
* - exp(logProbability(x)) == evaluate(x)
|
||||||
|
* - logNormalizationConstant() = log(normalizationConstant())
|
||||||
|
* - error >= 0.0
|
||||||
* - logProbability(x) == logNormalizationConstant() - error(x)
|
* - logProbability(x) == logNormalizationConstant() - error(x)
|
||||||
*
|
*
|
||||||
* @param conditional The conditional to test, as a reference to the derived type.
|
* @param conditional The conditional to test, as a reference to the derived type.
|
||||||
|
|
Loading…
Reference in New Issue