diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index f61b280cb..1913be7aa 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -298,9 +298,14 @@ double GaussianMixture::error(const HybridValues &values) const { /* *******************************************************************************/ double GaussianMixture::logProbability(const HybridValues &values) const { - // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(values.discrete()); return conditional->logProbability(values.continuous()); } +/* *******************************************************************************/ +double GaussianMixture::evaluate(const HybridValues &values) const { + auto conditional = conditionals_(values.discrete()); + return conditional->evaluate(values.continuous()); +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index a8d07cbc8..2137acff6 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -175,7 +175,7 @@ class GTSAM_EXPORT GaussianMixture /** * @brief Compute the error of this Gaussian Mixture. - * + * * log(probability(x)) = K - error(x) * * @param values Continuous values and discrete assignment. @@ -191,12 +191,13 @@ class GTSAM_EXPORT GaussianMixture */ double logProbability(const HybridValues &values) const override; - // /// Calculate probability density for given values `x`. - // double evaluate(const HybridValues &values) const; + /// Calculate probability density for given `values`. + double evaluate(const HybridValues &values) const override; - // /// Evaluate probability density, sugar. - // double operator()(const HybridValues &values) const { return - // evaluate(values); } + /// Evaluate probability density, sugar. + double operator()(const HybridValues &values) const { + return evaluate(values); + } /** * @brief Prune the decision tree of Gaussian factors as per the discrete diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 55fd5d5d4..24f61a85f 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -151,4 +151,24 @@ double HybridConditional::logProbability(const HybridValues &values) const { "HybridConditional::logProbability: conditional type not handled"); } +/* ************************************************************************ */ +double HybridConditional::logNormalizationConstant() const { + if (auto gc = asGaussian()) { + return gc->logNormalizationConstant(); + } + if (auto gm = asMixture()) { + return gm->logNormalizationConstant(); // 0.0! + } + if (auto dc = asDiscrete()) { + return dc->logNormalizationConstant(); // 0.0! + } + throw std::runtime_error( + "HybridConditional::logProbability: conditional type not handled"); +} + +/* ************************************************************************ */ +double HybridConditional::evaluate(const HybridValues &values) const { + return std::exp(logProbability(values)); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 19c070974..c8cb968df 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -179,9 +179,19 @@ class GTSAM_EXPORT HybridConditional /// Return the error of the underlying conditional. double error(const HybridValues& values) const override; - /// Return the logProbability of the underlying conditional. + /// Return the log-probability (or density) of the underlying conditional. double logProbability(const HybridValues& values) const override; + /** + * Return the log normalization constant. + * Note this is 0.0 for discrete and hybrid conditionals, but depends + * on the continuous parameters for Gaussian conditionals. + */ + double logNormalizationConstant() const override; + + /// Return the probability (or density) of the underlying conditional. + double evaluate(const HybridValues& values) const override; + /// Check if VectorValues `measurements` contains all frontal keys. bool frontalsIn(const VectorValues& measurements) const { for (Key key : frontals()) { diff --git a/gtsam/hybrid/tests/testHybridConditional.cpp b/gtsam/hybrid/tests/testHybridConditional.cpp new file mode 100644 index 000000000..da766a56f --- /dev/null +++ b/gtsam/hybrid/tests/testHybridConditional.cpp @@ -0,0 +1,75 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridConditional.cpp + * @brief Unit tests for HybridConditional class + * @date January 2023 + */ + +#include + +#include "TinyHybridExample.h" + +// Include for test suite +#include + +using namespace gtsam; + +using symbol_shorthand::M; +using symbol_shorthand::X; +using symbol_shorthand::Z; + +/* ****************************************************************************/ +// Check invariants for all conditionals in a tiny Bayes net. +TEST(HybridConditional, Invariants) { + // Create hybrid Bayes net p(z|x,m)p(x)P(m) + auto bn = tiny::createHybridBayesNet(); + + // Create values to check invariants. + const VectorValues c{{X(0), Vector1(5.1)}, {Z(0), Vector1(4.9)}}; + const DiscreteValues d{{M(0), 1}}; + const HybridValues values{c, d}; + + // Check invariants for p(z|x,m) + auto hc1 = bn.at(0); + CHECK(hc1->isHybrid()); + GTSAM_PRINT(*hc1); + + // Check invariants as a GaussianMixture. + const auto mixture = hc1->asMixture(); + double probability = mixture->evaluate(values); + CHECK(probability >= 0.0); + EXPECT_DOUBLES_EQUAL(probability, (*mixture)(values), 1e-9); + double logProb = mixture->logProbability(values); + EXPECT_DOUBLES_EQUAL(probability, std::exp(logProb), 1e-9); + double expected = + mixture->logNormalizationConstant() - mixture->error(values); + EXPECT_DOUBLES_EQUAL(logProb, expected, 1e-9); + EXPECT(GaussianMixture::CheckInvariants(*mixture, values)); + + // Check invariants as a HybridConditional. + probability = hc1->evaluate(values); + CHECK(probability >= 0.0); + EXPECT_DOUBLES_EQUAL(probability, (*hc1)(values), 1e-9); + logProb = hc1->logProbability(values); + EXPECT_DOUBLES_EQUAL(probability, std::exp(logProb), 1e-9); + expected = hc1->logNormalizationConstant() - hc1->error(values); + EXPECT_DOUBLES_EQUAL(logProb, expected, 1e-9); + EXPECT(HybridConditional::CheckInvariants(*hc1, values)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */