From 30bf261c15d8190c9eea09057bf11269d38628c7 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 21 Aug 2024 20:11:00 -0400 Subject: [PATCH] Tests which verify direct factor specification works well --- .../tests/testGaussianMixtureFactor.cpp | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 7926cdc5a..10541cf88 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -388,6 +388,151 @@ TEST(GaussianMixtureFactor, DifferentCovariances) { EXPECT(assert_equal(expected_m1, actual_m1)); } +HybridGaussianFactorGraph CreateFactorGraph(const gtsam::Values &values, + const std::vector &mus, + const std::vector &sigmas, + DiscreteKey &m1) { + auto model0 = noiseModel::Isotropic::Sigma(1, sigmas[0]); + auto model1 = noiseModel::Isotropic::Sigma(1, sigmas[1]); + auto prior_noise = noiseModel::Isotropic::Sigma(1, 1e-3); + + auto f0 = std::make_shared>(X(0), X(1), mus[0], model0) + ->linearize(values); + auto f1 = std::make_shared>(X(0), X(1), mus[1], model1) + ->linearize(values); + + // Create GaussianMixtureFactor + std::vector factors{f0, f1}; + AlgebraicDecisionTree logNormalizers( + {m1}, std::vector{ComputeLogNormalizer(model0), + ComputeLogNormalizer(model1)}); + GaussianMixtureFactor mixtureFactor({X(0), X(1)}, {m1}, factors, + logNormalizers); + + HybridGaussianFactorGraph hfg; + hfg.push_back(mixtureFactor); + + hfg.push_back(PriorFactor(X(0), values.at(X(0)), prior_noise) + .linearize(values)); + + return hfg; +} + +TEST(GaussianMixtureFactor, DifferentMeansFG) { + DiscreteKey m1(M(1), 2); + + Values values; + double x1 = 0.0, x2 = 1.75; + values.insert(X(0), x1); + values.insert(X(1), x2); + + std::vector mus = {0.0, 2.0}, sigmas = {1e-0, 1e-0}; + + HybridGaussianFactorGraph hfg = CreateFactorGraph(values, mus, sigmas, m1); + + { + auto bn = hfg.eliminateSequential(); + HybridValues actual = bn->optimize(); + + HybridValues expected( + VectorValues{{X(0), Vector1(0.0)}, {X(1), Vector1(-1.75)}}, + DiscreteValues{{M(1), 0}}); + + EXPECT(assert_equal(expected, actual)); + + { + DiscreteValues dv{{M(1), 0}}; + VectorValues cont = bn->optimize(dv); + double error = bn->error(HybridValues(cont, dv)); + // regression + EXPECT_DOUBLES_EQUAL(0.69314718056, error, 1e-9); + } + { + DiscreteValues dv{{M(1), 1}}; + VectorValues cont = bn->optimize(dv); + double error = bn->error(HybridValues(cont, dv)); + // regression + EXPECT_DOUBLES_EQUAL(0.69314718056, error, 1e-9); + } + } + + { + auto prior_noise = noiseModel::Isotropic::Sigma(1, 1e-3); + hfg.push_back( + PriorFactor(X(1), mus[1], prior_noise).linearize(values)); + + auto bn = hfg.eliminateSequential(); + HybridValues actual = bn->optimize(); + + HybridValues expected( + VectorValues{{X(0), Vector1(0.0)}, {X(1), Vector1(0.25)}}, + DiscreteValues{{M(1), 1}}); + + EXPECT(assert_equal(expected, actual)); + + { + DiscreteValues dv{{M(1), 0}}; + VectorValues cont = bn->optimize(dv); + double error = bn->error(HybridValues(cont, dv)); + // regression + EXPECT_DOUBLES_EQUAL(2.12692448787, error, 1e-9); + } + { + DiscreteValues dv{{M(1), 1}}; + VectorValues cont = bn->optimize(dv); + double error = bn->error(HybridValues(cont, dv)); + // regression + EXPECT_DOUBLES_EQUAL(0.126928487854, error, 1e-9); + } + } +} + +/* ************************************************************************* */ +/** + * @brief Test components with differing covariances. + * The factor graph is + * *-X1-*-X2 + * | + * M1 + */ +TEST(GaussianMixtureFactor, DifferentCovariancesFG) { + DiscreteKey m1(M(1), 2); + + Values values; + double x1 = 1.0, x2 = 1.0; + values.insert(X(0), x1); + values.insert(X(1), x2); + + std::vector mus = {0.0, 0.0}, sigmas = {1e2, 1e-2}; + + // Create FG with GaussianMixtureFactor and prior on X1 + HybridGaussianFactorGraph mixture_fg = + CreateFactorGraph(values, mus, sigmas, m1); + + auto hbn = mixture_fg.eliminateSequential(); + + VectorValues cv; + cv.insert(X(0), Vector1(0.0)); + cv.insert(X(1), Vector1(0.0)); + + // Check that the error values at the MLE point μ. + AlgebraicDecisionTree errorTree = hbn->errorTree(cv); + hbn->errorTree(cv).print(); + hbn2->errorTree(cv).print(); + + DiscreteValues dv0{{M(1), 0}}; + DiscreteValues dv1{{M(1), 1}}; + + // regression + EXPECT_DOUBLES_EQUAL(9.90348755254, errorTree(dv0), 1e-9); + EXPECT_DOUBLES_EQUAL(0.69314718056, errorTree(dv1), 1e-9); + + DiscreteConditional expected_m1(m1, "0.5/0.5"); + DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete()); + + EXPECT(assert_equal(expected_m1, actual_m1)); +} + /* ************************************************************************* */ int main() { TestResult tr;