From 5e1931eb98c4f299b96ef46ce12e0e75e781bb37 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 31 Dec 2024 15:54:07 -0500 Subject: [PATCH] update testGaussianMixture --- gtsam/hybrid/tests/testGaussianMixture.cpp | 27 ++++++++++++++-------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 14bef5fbb..2de8d15ec 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -79,8 +80,9 @@ TEST(GaussianMixture, GaussianMixtureModel) { double midway = mu1 - mu0; auto eliminationResult = gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential(); - auto pMid = *eliminationResult->at(0)->asDiscrete(); - EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid)); + auto pMid = std::dynamic_pointer_cast( + eliminationResult->at(0)->asDiscrete()); + EXPECT(assert_equal(DiscreteTableConditional(m, "60/40"), *pMid)); // Everywhere else, the result should be a sigmoid. for (const double shift : {-4, -2, 0, 2, 4}) { @@ -90,7 +92,8 @@ TEST(GaussianMixture, GaussianMixtureModel) { // Workflow 1: convert HBN to HFG and solve auto eliminationResult1 = gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential(); - auto posterior1 = *eliminationResult1->at(0)->asDiscrete(); + auto posterior1 = *std::dynamic_pointer_cast( + eliminationResult1->at(0)->asDiscrete()); EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8); // Workflow 2: directly specify HFG and solve @@ -99,7 +102,8 @@ TEST(GaussianMixture, GaussianMixtureModel) { m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)}); hfg1.push_back(mixing); auto eliminationResult2 = hfg1.eliminateSequential(); - auto posterior2 = *eliminationResult2->at(0)->asDiscrete(); + auto posterior2 = *std::dynamic_pointer_cast( + eliminationResult2->at(0)->asDiscrete()); EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); } } @@ -133,13 +137,14 @@ TEST(GaussianMixture, GaussianMixtureModel2) { // Eliminate the graph! auto eliminationResultMax = gfg.eliminateSequential(); - // Equality of posteriors asserts that the elimination is correct (same ratios - // for all modes) + // Equality of posteriors asserts that the elimination is correct + // (same ratios for all modes) EXPECT(assert_equal(expectedDiscretePosterior, eliminationResultMax->discretePosterior(vv))); - auto pMax = *eliminationResultMax->at(0)->asDiscrete(); - EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4)); + auto pMax = *std::dynamic_pointer_cast( + eliminationResultMax->at(0)->asDiscrete()); + EXPECT(assert_equal(DiscreteTableConditional(m, "42/58"), pMax, 1e-4)); // Everywhere else, the result should be a bell curve like function. for (const double shift : {-4, -2, 0, 2, 4}) { @@ -149,7 +154,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) { // Workflow 1: convert HBN to HFG and solve auto eliminationResult1 = gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential(); - auto posterior1 = *eliminationResult1->at(0)->asDiscrete(); + auto posterior1 = *std::dynamic_pointer_cast( + eliminationResult1->at(0)->asDiscrete()); EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8); // Workflow 2: directly specify HFG and solve @@ -158,7 +164,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) { m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)}); hfg.push_back(mixing); auto eliminationResult2 = hfg.eliminateSequential(); - auto posterior2 = *eliminationResult2->at(0)->asDiscrete(); + auto posterior2 = *std::dynamic_pointer_cast( + eliminationResult2->at(0)->asDiscrete()); EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); } }