From f62805f8b38b5bfc24986c535310cbc34d7a241d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 23 Feb 2024 12:49:51 -0500 Subject: [PATCH] add method to select underlying continuous Gaussian graph given discrete assignment --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 24 ++++++++- gtsam/hybrid/HybridGaussianFactorGraph.h | 4 ++ .../tests/testHybridGaussianFactorGraph.cpp | 52 +++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index ea8bd0b05..32cdddec6 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -296,7 +296,8 @@ static std::shared_ptr createDiscreteFactor( // Logspace version of: // exp(-factor->error(kEmpty)) * conditional->normalizationConstant(); - // We take negative of the logNormalizationConstant `log(1/k)` to get `log(k)`. + // We take negative of the logNormalizationConstant `log(1/k)` + // to get `log(k)`. return -factor->error(kEmpty) + (-conditional->logNormalizationConstant()); }; @@ -326,6 +327,7 @@ static std::shared_ptr createGaussianMixtureFactor( auto hf = std::dynamic_pointer_cast(factor); if (!hf) throw std::runtime_error("Expected HessianFactor!"); // Add 2.0 term since the constant term will be premultiplied by 0.5 + // as per the Hessian definition hf->constantTerm() += 2.0 * conditional->logNormalizationConstant(); } return factor; @@ -563,4 +565,24 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( return prob_tree; } +/* ************************************************************************ */ +GaussianFactorGraph HybridGaussianFactorGraph::operator()( + const DiscreteValues &assignment) const { + GaussianFactorGraph gfg; + for (auto &&f : *this) { + if (auto gf = std::dynamic_pointer_cast(f)) { + gfg.push_back(gf); + } else if (auto gc = std::dynamic_pointer_cast(f)) { + gfg.push_back(gf); + } else if (auto gmf = std::dynamic_pointer_cast(f)) { + gfg.push_back((*gmf)(assignment)); + } else if (auto gm = dynamic_pointer_cast(f)) { + gfg.push_back((*gm)(assignment)); + } else { + continue; + } + } + return gfg; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 8a86a7335..415464ecd 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -210,6 +210,10 @@ class GTSAM_EXPORT HybridGaussianFactorGraph GaussianFactorGraphTree assembleGraphTree() const; /// @} + + /// Get the GaussianFactorGraph at a given discrete assignment. + GaussianFactorGraph operator()(const DiscreteValues& assignment) const; + }; } // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 5be2f2742..b97dcef72 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -490,6 +490,58 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) { } } +/* ****************************************************************************/ +// Select a particular continuous factor graph given a discrete assignment +TEST(HybridGaussianFactorGraph, DiscreteSelection) { + Switching s(3); + + HybridGaussianFactorGraph graph = s.linearizedFactorGraph; + + DiscreteValues dv00{{M(0), 0}, {M(1), 0}}; + GaussianFactorGraph continuous_00 = graph(dv00); + GaussianFactorGraph expected_00; + expected_00.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10))); + expected_00.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-1))); + expected_00.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-1))); + expected_00.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10))); + expected_00.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10))); + + EXPECT(assert_equal(expected_00, continuous_00)); + + DiscreteValues dv01{{M(0), 0}, {M(1), 1}}; + GaussianFactorGraph continuous_01 = graph(dv01); + GaussianFactorGraph expected_01; + expected_01.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10))); + expected_01.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-1))); + expected_01.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-0))); + expected_01.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10))); + expected_01.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10))); + + EXPECT(assert_equal(expected_01, continuous_01)); + + DiscreteValues dv10{{M(0), 1}, {M(1), 0}}; + GaussianFactorGraph continuous_10 = graph(dv10); + GaussianFactorGraph expected_10; + expected_10.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10))); + expected_10.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-0))); + expected_10.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-1))); + expected_10.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10))); + expected_10.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10))); + + EXPECT(assert_equal(expected_10, continuous_10)); + + DiscreteValues dv11{{M(0), 1}, {M(1), 1}}; + GaussianFactorGraph continuous_11 = graph(dv11); + GaussianFactorGraph expected_11; + expected_11.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10))); + expected_11.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-0))); + expected_11.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-0))); + expected_11.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10))); + expected_11.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10))); + + EXPECT(assert_equal(expected_11, continuous_11)); +} + /* ************************************************************************* */ TEST(HybridGaussianFactorGraph, optimize) { HybridGaussianFactorGraph hfg;