From 611f61c7f4189c668326073f66185e47cd9b2fc3 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 29 Dec 2022 13:21:20 -0500 Subject: [PATCH] proto code for likelihood --- gtsam/hybrid/tests/testGaussianMixture.cpp | 85 ++++++++++++++++++---- 1 file changed, 70 insertions(+), 15 deletions(-) diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index fe6a57dee..5542d86a9 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -135,19 +135,12 @@ TEST(GaussianMixture, Error) { } /* ************************************************************************* */ -// Create a likelihood factor for a Gaussian mixture, return a Mixture factor on -// the parents. -GaussianMixtureFactor::shared_ptr likelihood(const HybridValues& values) { - GaussianMixtureFactor::shared_ptr factor; - return factor; -} - -/// Check that likelihood returns a mixture factor on the parents. -TEST(GaussianMixture, Likelihood) { - // Create mode key: 0 is low-noise, 1 is high-noise. - Key modeKey = M(0); - DiscreteKey mode(modeKey, 2); +// Create mode key: 0 is low-noise, 1 is high-noise. +static const Key modeKey = M(0); +static const DiscreteKey mode(modeKey, 2); +// Create a simple GaussianMixture +static GaussianMixture createSimpleGaussianMixture() { // Create Gaussian mixture Z(0) = X(0) + noise. // TODO(dellaert): making copies below is not ideal ! Matrix1 I = Matrix1::Identity(); @@ -157,15 +150,77 @@ TEST(GaussianMixture, Likelihood) { GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3)); const auto gm = GaussianMixture::FromConditionals( {Z(0)}, {X(0)}, {mode}, {conditional0, conditional1}); + return gm; +} + +/* ************************************************************************* */ +std::set DiscreteKeysAsSet(const DiscreteKeys& dkeys) { + std::set s; + s.insert(dkeys.begin(), dkeys.end()); + return s; +} + +// Get only the continuous parent keys as a KeyVector: +KeyVector continuousParents(const GaussianMixture& gm) { + // Get all parent keys: + const auto range = gm.parents(); + KeyVector continuousParentKeys(range.begin(), range.end()); + // Loop over all discrete keys: + for (const auto& discreteKey : gm.discreteKeys()) { + const Key key = discreteKey.first; + // remove that key from continuousParentKeys: + continuousParentKeys.erase(std::remove(continuousParentKeys.begin(), + continuousParentKeys.end(), key), + continuousParentKeys.end()); + } + return continuousParentKeys; +} + +// Create a test for continuousParents. +TEST(GaussianMixture, ContinuousParents) { + const GaussianMixture gm = createSimpleGaussianMixture(); + const KeyVector continuousParentKeys = continuousParents(gm); + // Check that the continuous parent keys are correct: + EXPECT(continuousParentKeys.size() == 1); + EXPECT(continuousParentKeys[0] == X(0)); +} + +/* ************************************************************************* */ +// Create a likelihood factor for a Gaussian mixture, return a Mixture factor. +GaussianMixtureFactor::shared_ptr likelihood(const GaussianMixture& gm, + const VectorValues& frontals) { + // TODO(dellaert): check that values has all frontals + const DiscreteKeys discreteParentKeys = gm.discreteKeys(); + const KeyVector continuousParentKeys = continuousParents(gm); + const GaussianMixtureFactor::Factors likelihoods( + gm.conditionals(), + [&](const GaussianConditional::shared_ptr& conditional) { + return conditional->likelihood(frontals); + }); + return boost::make_shared( + continuousParentKeys, discreteParentKeys, likelihoods); +} + +/// Check that likelihood returns a mixture factor on the parents. +TEST(GaussianMixture, Likelihood) { + const GaussianMixture gm = createSimpleGaussianMixture(); // Call the likelihood function: VectorValues measurements; measurements.insert(Z(0), Vector1(0)); - HybridValues values(DiscreteValues(), measurements); - const auto factor = likelihood(values); + const auto factor = likelihood(gm, measurements); // Check that the factor is a mixture factor on the parents. - const GaussianMixtureFactor expected = GaussianMixtureFactor(); + // Loop over all discrete assignments over the discrete parents: + const DiscreteKeys discreteParentKeys = gm.discreteKeys(); + + // Apply the likelihood function to all conditionals: + const GaussianMixtureFactor::Factors factors( + gm.conditionals(), + [measurements](const GaussianConditional::shared_ptr& conditional) { + return conditional->likelihood(measurements); + }); + const GaussianMixtureFactor expected({X(0)}, {mode}, factors); EXPECT(assert_equal(*factor, expected)); }