diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 9de8aba59..24d166095 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -170,21 +170,41 @@ KeyVector GaussianMixture::continuousParents() const { } /* ************************************************************************* */ -boost::shared_ptr GaussianMixture::likelihood( - const VectorValues &frontals) const { - // Check that values has all frontals - for (auto &&kv : frontals) { - if (frontals.find(kv.first) == frontals.end()) { - throw std::runtime_error("GaussianMixture: frontals missing factor key."); +boost::shared_ptr GaussianMixture::normalizationConstants() + const { + DecisionTree constants( + conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { + return conditional->normalizationConstant(); + }); + // If all constants the same, return nullptr: + if (constants.nrLeaves() == 1) return nullptr; + return boost::make_shared(discreteKeys(), constants); +} + +/* ************************************************************************* */ +bool GaussianMixture::allFrontalsGiven(const VectorValues &given) const { + for (auto &&kv : given) { + if (given.find(kv.first) == given.end()) { + return false; } } + return true; +} + +/* ************************************************************************* */ +boost::shared_ptr GaussianMixture::likelihood( + const VectorValues &given) const { + if (!allFrontalsGiven(given)) { + throw std::runtime_error( + "GaussianMixture::likelihood: given values are missing some frontals."); + } const DiscreteKeys discreteParentKeys = discreteKeys(); const KeyVector continuousParentKeys = continuousParents(); const GaussianMixtureFactor::Factors likelihoods( conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { return GaussianMixtureFactor::FactorAndConstant{ - conditional->likelihood(frontals), + conditional->likelihood(given), conditional->logNormalizationConstant()}; }); return boost::make_shared( @@ -285,8 +305,7 @@ AlgebraicDecisionTree GaussianMixture::logProbability( return 1e50; } }; - DecisionTree errorTree(conditionals_, errorFunc); - return errorTree; + return DecisionTree(conditionals_, errorFunc); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index d90e08409..9504f7ffa 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -155,10 +155,16 @@ class GTSAM_EXPORT GaussianMixture /// Returns the continuous keys among the parents. KeyVector continuousParents() const; - // Create a likelihood factor for a Gaussian mixture, return a Mixture factor - // on the parents. + /// Return a discrete factor with possibly varying normalization constants. + /// If there is no variation, return nullptr. + boost::shared_ptr normalizationConstants() const; + + /** + * Create a likelihood factor for a Gaussian mixture, return a Mixture factor + * on the parents. + */ boost::shared_ptr likelihood( - const VectorValues &frontals) const; + const VectorValues &given) const; /// Getter for the underlying Conditionals DecisionTree const Conditionals &conditionals() const; @@ -233,6 +239,9 @@ class GTSAM_EXPORT GaussianMixture /// @} private: + /// Check whether `given` has values for all frontal keys. + bool allFrontalsGiven(const VectorValues &given) const; + /** Serialization function */ friend class boost::serialization::access; template diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index a2ee2c21f..5bad40728 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -106,13 +106,16 @@ TEST(GaussianMixture, Error) { conditional1 = boost::make_shared(X(1), d2, R2, X(2), S2, model); - // Create decision tree + // Create Gaussian Mixture. DiscreteKey m1(M(1), 2); GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals); + // Check that normalizationConstants returns nullptr, as all constants equal. + CHECK(!mixture.normalizationConstants()); + VectorValues values; values.insert(X(1), Vector2::Ones()); values.insert(X(2), Vector2::Zero()); @@ -163,6 +166,19 @@ TEST(GaussianMixture, ContinuousParents) { EXPECT(continuousParentKeys[0] == X(0)); } +/* ************************************************************************* */ +/// Check we can create a DecisionTreeFactor with all normalization constants. +TEST(GaussianMixture, NormalizationConstants) { + const GaussianMixture gm = createSimpleGaussianMixture(); + + const auto factor = gm.normalizationConstants(); + + // Test with 1D Gaussian normalization constants for sigma 0.5 and 3: + auto c = [](double sigma) { return 1.0 / (sqrt(2 * M_PI) * sigma); }; + const DecisionTreeFactor expected({M(0), 2}, {c(0.5), c(3)}); + EXPECT(assert_equal(expected, *factor)); +} + /* ************************************************************************* */ /// Check that likelihood returns a mixture factor on the parents. TEST(GaussianMixture, Likelihood) { @@ -186,7 +202,7 @@ TEST(GaussianMixture, Likelihood) { conditional->logNormalizationConstant()}; }); const GaussianMixtureFactor expected({X(0)}, {mode}, factors); - EXPECT(assert_equal(*factor, expected)); + EXPECT(assert_equal(expected, *factor)); } /* ************************************************************************* */