diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index fd437f52c..881a97a1b 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -38,15 +38,6 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { return e != nullptr && Base::equals(*e, tol); } -/* *******************************************************************************/ -GaussianMixtureFactor GaussianMixtureFactor::FromFactors( - const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const std::vector &factors) { - Factors dt(discreteKeys, factors); - - return GaussianMixtureFactor(continuousKeys, discreteKeys, dt); -} - /* *******************************************************************************/ void GaussianMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const { diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 0b65b5aa9..b8f475de3 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -93,19 +93,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @brief Construct a new GaussianMixtureFactor object using a vector of * GaussianFactor shared pointers. * - * @param keys Vector of keys for continuous factors. + * @param continuousKeys Vector of keys for continuous factors. * @param discreteKeys Vector of discrete keys. * @param factors Vector of gaussian factor shared pointers. */ - GaussianMixtureFactor(const KeyVector &keys, const DiscreteKeys &discreteKeys, + GaussianMixtureFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys, const std::vector &factors) - : GaussianMixtureFactor(keys, discreteKeys, + : GaussianMixtureFactor(continuousKeys, discreteKeys, Factors(discreteKeys, factors)) {} - static This FromFactors( - const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, - const std::vector &factors); - /// @} /// @name Testable /// @{ diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index ab070c68f..2721612f9 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -54,7 +54,7 @@ virtual class HybridDiscreteFactor { #include class GaussianMixtureFactor : gtsam::HybridFactor { - static GaussianMixtureFactor FromFactors( + GaussianMixtureFactor( const gtsam::KeyVector& continuousKeys, const gtsam::DiscreteKeys& discreteKeys, const std::vector& factorsList); diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h index f9e1916d0..59c57f8a0 100644 --- a/gtsam/hybrid/tests/Switching.h +++ b/gtsam/hybrid/tests/Switching.h @@ -57,7 +57,7 @@ inline HybridGaussianFactorGraph::shared_ptr makeSwitchingChain( // keyFunc(1) to keyFunc(n+1) for (size_t t = 1; t < n; t++) { - hfg.add(GaussianMixtureFactor::FromFactors( + hfg.add(GaussianMixtureFactor( {keyFunc(t), keyFunc(t + 1)}, {{dKeyFunc(t), 2}}, {boost::make_shared(keyFunc(t), I_3x3, keyFunc(t + 1), I_3x3, Z_3x1), diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 310081f02..fe6a57dee 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -20,6 +20,8 @@ #include #include +#include +#include #include #include @@ -33,6 +35,7 @@ using namespace gtsam; using noiseModel::Isotropic; using symbol_shorthand::M; using symbol_shorthand::X; +using symbol_shorthand::Z; /* ************************************************************************* */ /* Check construction of GaussianMixture P(x1 | x2, m1) as well as accessing a @@ -127,7 +130,43 @@ TEST(GaussianMixture, Error) { assignment[M(1)] = 0; EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8); assignment[M(1)] = 1; - EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), 1e-8); + EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), + 1e-8); +} + +/* ************************************************************************* */ +// Create a likelihood factor for a Gaussian mixture, return a Mixture factor on +// the parents. +GaussianMixtureFactor::shared_ptr likelihood(const HybridValues& values) { + GaussianMixtureFactor::shared_ptr factor; + return factor; +} + +/// Check that likelihood returns a mixture factor on the parents. +TEST(GaussianMixture, Likelihood) { + // Create mode key: 0 is low-noise, 1 is high-noise. + Key modeKey = M(0); + DiscreteKey mode(modeKey, 2); + + // Create Gaussian mixture Z(0) = X(0) + noise. + // TODO(dellaert): making copies below is not ideal ! + Matrix1 I = Matrix1::Identity(); + const auto conditional0 = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5)); + const auto conditional1 = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3)); + const auto gm = GaussianMixture::FromConditionals( + {Z(0)}, {X(0)}, {mode}, {conditional0, conditional1}); + + // Call the likelihood function: + VectorValues measurements; + measurements.insert(Z(0), Vector1(0)); + HybridValues values(DiscreteValues(), measurements); + const auto factor = likelihood(values); + + // Check that the factor is a mixture factor on the parents. + const GaussianMixtureFactor expected = GaussianMixtureFactor(); + EXPECT(assert_equal(*factor, expected)); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 55e4c28ad..f774e8ef1 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -176,7 +176,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) { hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); - hfg.add(GaussianMixtureFactor::FromFactors( + hfg.add(GaussianMixtureFactor( {X(1)}, {{M(1), 2}}, {boost::make_shared(X(1), I_3x3, Z_3x1), boost::make_shared(X(1), I_3x3, Vector3::Ones())})); @@ -235,7 +235,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) { hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1)); { - hfg.add(GaussianMixtureFactor::FromFactors( + hfg.add(GaussianMixtureFactor( {X(0)}, {{M(0), 2}}, {boost::make_shared(X(0), I_3x3, Z_3x1), boost::make_shared(X(0), I_3x3, Vector3::Ones())}));