move normalization constant to base class
parent
5c59862238
commit
8c752049de
|
|
@ -56,4 +56,27 @@ double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
|
||||||
const HybridValues& c) const {
|
const HybridValues& c) const {
|
||||||
throw std::runtime_error("Conditional::evaluate is not implemented");
|
throw std::runtime_error("Conditional::evaluate is not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||||
|
double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const {
|
||||||
|
return std::exp(logNormalizationConstant());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template <class FACTOR, class DERIVEDCONDITIONAL>
|
||||||
|
bool Conditional<FACTOR, DERIVEDCONDITIONAL>::checkInvariants(
|
||||||
|
const HybridValues& values) const {
|
||||||
|
const double probability = evaluate(values);
|
||||||
|
if (probability < 0.0 || probability > 1.0)
|
||||||
|
return false; // probability is not in [0,1]
|
||||||
|
const double logProb = logProbability(values);
|
||||||
|
if (std::abs(probability - std::exp(logProb)) > 1e-9)
|
||||||
|
return false; // logProb is not consistent with probability
|
||||||
|
const double expected =
|
||||||
|
this->logNormalizationConstant() - this->error(values);
|
||||||
|
if (std::abs(logProb - expected) > 1e-9)
|
||||||
|
return false; // logProb is not consistent with error
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -141,6 +141,15 @@ namespace gtsam {
|
||||||
return evaluate(x);
|
return evaluate(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* By default, log normalization constant = 0.0.
|
||||||
|
* Override if this depends on the parameters.
|
||||||
|
*/
|
||||||
|
virtual double logNormalizationConstant() const;
|
||||||
|
|
||||||
|
/** Non-virtual, exponentiate logNormalizationConstant. */
|
||||||
|
double normalizationConstant() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -172,7 +181,16 @@ namespace gtsam {
|
||||||
/** Mutable iterator pointing past the last parent key. */
|
/** Mutable iterator pointing past the last parent key. */
|
||||||
typename FACTOR::iterator endParents() { return asFactor().end(); }
|
typename FACTOR::iterator endParents() { return asFactor().end(); }
|
||||||
|
|
||||||
|
/** Check that the invariants hold for derived class at a given point. */
|
||||||
|
bool checkInvariants(const HybridValues& values) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
/// @name Serialization
|
||||||
|
/// @{
|
||||||
|
|
||||||
// Cast to factor type (non-const) (casts down to derived conditional type, then up to factor type)
|
// Cast to factor type (non-const) (casts down to derived conditional type, then up to factor type)
|
||||||
FACTOR& asFactor() { return static_cast<FACTOR&>(static_cast<DERIVEDCONDITIONAL&>(*this)); }
|
FACTOR& asFactor() { return static_cast<FACTOR&>(static_cast<DERIVEDCONDITIONAL&>(*this)); }
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -136,14 +136,7 @@ namespace gtsam {
|
||||||
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||||
* log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
* log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma)
|
||||||
*/
|
*/
|
||||||
double logNormalizationConstant() const;
|
double logNormalizationConstant() const override;
|
||||||
|
|
||||||
/**
|
|
||||||
* normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
|
||||||
*/
|
|
||||||
inline double normalizationConstant() const {
|
|
||||||
return exp(logNormalizationConstant());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate log-probability log(evaluate(x)) for given values `x`:
|
* Calculate log-probability log(evaluate(x)) for given values `x`:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue