diff --git a/gtsam/inference/Conditional-inst.h b/gtsam/inference/Conditional-inst.h index ee13946d9..5a17c44cc 100644 --- a/gtsam/inference/Conditional-inst.h +++ b/gtsam/inference/Conditional-inst.h @@ -56,4 +56,27 @@ double Conditional::evaluate( const HybridValues& c) const { throw std::runtime_error("Conditional::evaluate is not implemented"); } + +/* ************************************************************************* */ +template +double Conditional::normalizationConstant() const { + return std::exp(logNormalizationConstant()); +} + +/* ************************************************************************* */ +template +bool Conditional::checkInvariants( + const HybridValues& values) const { + const double probability = evaluate(values); + if (probability < 0.0 || probability > 1.0) + return false; // probability is not in [0,1] + const double logProb = logProbability(values); + if (std::abs(probability - std::exp(logProb)) > 1e-9) + return false; // logProb is not consistent with probability + const double expected = + this->logNormalizationConstant() - this->error(values); + if (std::abs(logProb - expected) > 1e-9) + return false; // logProb is not consistent with error +} + } // namespace gtsam diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index 9083c5c1a..bb75f9c6e 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -141,6 +141,15 @@ namespace gtsam { return evaluate(x); } + /** + * By default, log normalization constant = 0.0. + * Override if this depends on the parameters. + */ + virtual double logNormalizationConstant() const; + + /** Non-virtual, exponentiate logNormalizationConstant. */ + double normalizationConstant() const; + /// @} /// @name Advanced Interface /// @{ @@ -172,7 +181,16 @@ namespace gtsam { /** Mutable iterator pointing past the last parent key. */ typename FACTOR::iterator endParents() { return asFactor().end(); } + /** Check that the invariants hold for derived class at a given point. */ + bool checkInvariants(const HybridValues& values) const; + + /// @} + private: + + /// @name Serialization + /// @{ + // Cast to factor type (non-const) (casts down to derived conditional type, then up to factor type) FACTOR& asFactor() { return static_cast(static_cast(*this)); } diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 880d13064..69e2ef2d3 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -136,14 +136,7 @@ namespace gtsam { * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) * log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) */ - double logNormalizationConstant() const; - - /** - * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) - */ - inline double normalizationConstant() const { - return exp(logNormalizationConstant()); - } + double logNormalizationConstant() const override; /** * Calculate log-probability log(evaluate(x)) for given values `x`: