From 1673c47ea0b232e2c00adf9d60cc9b67c50ead3b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 29 Aug 2024 14:23:18 -0400 Subject: [PATCH] unpack HybridConditional in errorTree --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 5 ++ .../tests/testHybridGaussianFactorGraph.cpp | 51 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 7ac6cef98..6c3442f97 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -551,6 +551,11 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( // TODO(dellaert): just use a virtual method defined in HybridFactor. AlgebraicDecisionTree factor_error; + auto f = factor; + if (auto hc = dynamic_pointer_cast(factor)) { + f = hc->inner(); + } + if (auto gaussianMixture = dynamic_pointer_cast(f)) { // Compute factor error and add it. error_tree = error_tree + gaussianMixture->errorTree(continuousValues); diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 5be2f2742..68b3b8215 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -598,6 +598,57 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { EXPECT(assert_equal(expected_probs, probs, 1e-7)); } +/* ****************************************************************************/ +// Test hybrid gaussian factor graph errorTree when there is a HybridConditional in the graph +TEST(HybridGaussianFactorGraph, ErrorTreeWithConditional) { + using symbol_shorthand::F; + + DiscreteKey m1(M(1), 2); + Key z0 = Z(0), f01 = F(0); + Key x0 = X(0), x1 = X(1); + + HybridBayesNet hbn; + + auto prior_model = noiseModel::Isotropic::Sigma(1, 1e-1); + auto measurement_model = noiseModel::Isotropic::Sigma(1, 2.0); + + // Set a prior P(x0) at x0=0 + hbn.emplace_back( + new GaussianConditional(x0, Vector1(0.0), I_1x1, prior_model)); + + // Add measurement P(z0 | x0) + hbn.emplace_back(new GaussianConditional(z0, Vector1(0.0), -I_1x1, x0, I_1x1, + measurement_model)); + + // Add hybrid motion model + double mu = 0.0; + double sigma0 = 1e2, sigma1 = 1e-2; + auto model0 = noiseModel::Isotropic::Sigma(1, sigma0); + auto model1 = noiseModel::Isotropic::Sigma(1, sigma1); + auto c0 = make_shared(f01, Vector1(mu), I_1x1, x1, I_1x1, + x0, -I_1x1, model0), + c1 = make_shared(f01, Vector1(mu), I_1x1, x1, I_1x1, + x0, -I_1x1, model1); + hbn.emplace_back(new GaussianMixture({f01}, {x0, x1}, {m1}, {c0, c1})); + + // Discrete uniform prior. + hbn.emplace_back(new DiscreteConditional(m1, "0.5/0.5")); + + VectorValues given; + given.insert(z0, Vector1(0.0)); + given.insert(f01, Vector1(0.0)); + auto gfg = hbn.toFactorGraph(given); + + VectorValues vv; + vv.insert(x0, Vector1(1.0)); + vv.insert(x1, Vector1(2.0)); + AlgebraicDecisionTree errorTree = gfg.errorTree(vv); + + // regression + AlgebraicDecisionTree expected(m1, 59.335390372, 5050.125); + EXPECT(assert_equal(expected, errorTree, 1e-9)); +} + /* ****************************************************************************/ // Check that assembleGraphTree assembles Gaussian factor graphs for each // assignment.