diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 24d166095..e686241fc 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -204,8 +204,7 @@ boost::shared_ptr GaussianMixture::likelihood( const GaussianMixtureFactor::Factors likelihoods( conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { return GaussianMixtureFactor::FactorAndConstant{ - conditional->likelihood(given), - conditional->logNormalizationConstant()}; + conditional->likelihood(given), 0.0}; }); return boost::make_shared( continuousParentKeys, discreteParentKeys, likelihoods); diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index be9cdba85..fd1d24722 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -341,11 +341,13 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph( // replace it by a likelihood factor: for (auto &&conditional : *this) { if (conditional->frontalsIn(measurements)) { - if (auto gc = conditional->asGaussian()) + if (auto gc = conditional->asGaussian()) { fg.push_back(gc->likelihood(measurements)); - else if (auto gm = conditional->asMixture()) + } else if (auto gm = conditional->asMixture()) { fg.push_back(gm->likelihood(measurements)); - else { + const auto constantsFactor = gm->normalizationConstants(); + if (constantsFactor) fg.push_back(constantsFactor); + } else { throw std::runtime_error("Unknown conditional type"); } } else { diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 5bad40728..024aafbc7 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -198,8 +198,7 @@ TEST(GaussianMixture, Likelihood) { gm.conditionals(), [measurements](const GaussianConditional::shared_ptr& conditional) { return GaussianMixtureFactor::FactorAndConstant{ - conditional->likelihood(measurements), - conditional->logNormalizationConstant()}; + conditional->likelihood(measurements), 0.0}; }); const GaussianMixtureFactor expected({X(0)}, {mode}, factors); EXPECT(assert_equal(expected, *factor)); diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index f904ee5ba..ef89c0bfd 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -613,7 +613,7 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) { const int num_measurements = 1; auto fg = tiny::createHybridGaussianFactorGraph( num_measurements, VectorValues{{Z(0), Vector1(5.0)}}); - EXPECT_LONGS_EQUAL(3, fg.size()); + EXPECT_LONGS_EQUAL(4, fg.size()); // Assemble graph tree: auto actual = fg.assembleGraphTree(); @@ -625,7 +625,7 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) { CHECK(mixture); // Get prior factor: - const auto gf = boost::dynamic_pointer_cast(fg.at(1)); + const auto gf = boost::dynamic_pointer_cast(fg.at(2)); CHECK(gf); using GF = GaussianFactor::shared_ptr; const GF prior = gf->asGaussian(); @@ -654,7 +654,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) { const int num_measurements = 1; auto fg = tiny::createHybridGaussianFactorGraph( num_measurements, VectorValues{{Z(0), Vector1(5.0)}}); - EXPECT_LONGS_EQUAL(3, fg.size()); + EXPECT_LONGS_EQUAL(4, fg.size()); // Create expected Bayes Net: HybridBayesNet expectedBayesNet; @@ -686,7 +686,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) { auto fg = tiny::createHybridGaussianFactorGraph( num_measurements, VectorValues{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}}); - EXPECT_LONGS_EQUAL(4, fg.size()); + EXPECT_LONGS_EQUAL(6, fg.size()); // Create expected Bayes Net: HybridBayesNet expectedBayesNet; @@ -721,7 +721,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) { auto bn = tiny::createHybridBayesNet(num_measurements, manyModes); const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}}; auto fg = bn.toFactorGraph(measurements); - EXPECT_LONGS_EQUAL(5, fg.size()); + EXPECT_LONGS_EQUAL(7, fg.size()); // Test elimination const auto posterior = fg.eliminateSequential();