diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 82d16226a..a5d06f04d 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -128,6 +129,36 @@ void GaussianMixture::print(const std::string &s, }); } +/* ************************************************************************* */ +KeyVector GaussianMixture::continuousParents() const { + // Get all parent keys: + const auto range = parents(); + KeyVector continuousParentKeys(range.begin(), range.end()); + // Loop over all discrete keys: + for (const auto &discreteKey : discreteKeys()) { + const Key key = discreteKey.first; + // remove that key from continuousParentKeys: + continuousParentKeys.erase(std::remove(continuousParentKeys.begin(), + continuousParentKeys.end(), key), + continuousParentKeys.end()); + } + return continuousParentKeys; +} + +/* ************************************************************************* */ +boost::shared_ptr GaussianMixture::likelihood( + const VectorValues &frontals) const { + // TODO(dellaert): check that values has all frontals + const DiscreteKeys discreteParentKeys = discreteKeys(); + const KeyVector continuousParentKeys = continuousParents(); + const GaussianMixtureFactor::Factors likelihoods( + conditionals(), [&](const GaussianConditional::shared_ptr &conditional) { + return conditional->likelihood(frontals); + }); + return boost::make_shared( + continuousParentKeys, discreteParentKeys, likelihoods); +} + /* ************************************************************************* */ std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { std::set s; diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index a3393dbb0..672a886ad 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -29,6 +29,8 @@ namespace gtsam { +class GaussianMixtureFactor; + /** * @brief A conditional of gaussian mixtures indexed by discrete variables, as * part of a Bayes Network. This is the result of the elimination of a @@ -117,16 +119,6 @@ class GTSAM_EXPORT GaussianMixture const DiscreteKeys &discreteParents, const std::vector &conditionals); - /// @} - /// @name Standard API - /// @{ - - GaussianConditional::shared_ptr operator()( - const DiscreteValues &discreteValues) const; - - /// Returns the total number of continuous components - size_t nrComponents() const; - /// @} /// @name Testable /// @{ @@ -140,6 +132,22 @@ class GTSAM_EXPORT GaussianMixture const KeyFormatter &formatter = DefaultKeyFormatter) const override; /// @} + /// @name Standard API + /// @{ + + GaussianConditional::shared_ptr operator()( + const DiscreteValues &discreteValues) const; + + /// Returns the total number of continuous components + size_t nrComponents() const; + + /// Returns the continuous keys among the parents. + KeyVector continuousParents() const; + + // Create a likelihood factor for a Gaussian mixture, return a Mixture factor + // on the parents. + boost::shared_ptr likelihood( + const VectorValues &frontals) const; /// Getter for the underlying Conditionals DecisionTree const Conditionals &conditionals() const; @@ -181,6 +189,7 @@ class GTSAM_EXPORT GaussianMixture * @return Sum */ Sum add(const Sum &sum) const; + /// @} }; /// Return the DiscreteKey vector as a set. diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 5542d86a9..ed5771770 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -154,53 +154,16 @@ static GaussianMixture createSimpleGaussianMixture() { } /* ************************************************************************* */ -std::set DiscreteKeysAsSet(const DiscreteKeys& dkeys) { - std::set s; - s.insert(dkeys.begin(), dkeys.end()); - return s; -} - -// Get only the continuous parent keys as a KeyVector: -KeyVector continuousParents(const GaussianMixture& gm) { - // Get all parent keys: - const auto range = gm.parents(); - KeyVector continuousParentKeys(range.begin(), range.end()); - // Loop over all discrete keys: - for (const auto& discreteKey : gm.discreteKeys()) { - const Key key = discreteKey.first; - // remove that key from continuousParentKeys: - continuousParentKeys.erase(std::remove(continuousParentKeys.begin(), - continuousParentKeys.end(), key), - continuousParentKeys.end()); - } - return continuousParentKeys; -} - // Create a test for continuousParents. TEST(GaussianMixture, ContinuousParents) { const GaussianMixture gm = createSimpleGaussianMixture(); - const KeyVector continuousParentKeys = continuousParents(gm); + const KeyVector continuousParentKeys = gm.continuousParents(); // Check that the continuous parent keys are correct: EXPECT(continuousParentKeys.size() == 1); EXPECT(continuousParentKeys[0] == X(0)); } /* ************************************************************************* */ -// Create a likelihood factor for a Gaussian mixture, return a Mixture factor. -GaussianMixtureFactor::shared_ptr likelihood(const GaussianMixture& gm, - const VectorValues& frontals) { - // TODO(dellaert): check that values has all frontals - const DiscreteKeys discreteParentKeys = gm.discreteKeys(); - const KeyVector continuousParentKeys = continuousParents(gm); - const GaussianMixtureFactor::Factors likelihoods( - gm.conditionals(), - [&](const GaussianConditional::shared_ptr& conditional) { - return conditional->likelihood(frontals); - }); - return boost::make_shared( - continuousParentKeys, discreteParentKeys, likelihoods); -} - /// Check that likelihood returns a mixture factor on the parents. TEST(GaussianMixture, Likelihood) { const GaussianMixture gm = createSimpleGaussianMixture(); @@ -208,7 +171,7 @@ TEST(GaussianMixture, Likelihood) { // Call the likelihood function: VectorValues measurements; measurements.insert(Z(0), Vector1(0)); - const auto factor = likelihood(gm, measurements); + const auto factor = gm.likelihood(measurements); // Check that the factor is a mixture factor on the parents. // Loop over all discrete assignments over the discrete parents: