diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index d96a890f4..235ffc87f 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -67,6 +67,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const { for (const Key& key : p->continuousKeys()) { keys.insert(key); } + } else if (auto p = std::dynamic_pointer_cast(factor)) { + keys.insert(p->keys().begin(), p->keys().end()); } } return keys; diff --git a/gtsam/hybrid/tests/testHybridFactorGraph.cpp b/gtsam/hybrid/tests/testHybridFactorGraph.cpp index f5b4ec0b1..33c0761eb 100644 --- a/gtsam/hybrid/tests/testHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridFactorGraph.cpp @@ -18,7 +18,9 @@ #include #include #include +#include #include +#include #include using namespace std; @@ -37,6 +39,32 @@ TEST(HybridFactorGraph, Constructor) { HybridFactorGraph fg; } +/* ************************************************************************* */ +// Test if methods to get keys work as expected. +TEST(HybridFactorGraph, Keys) { + HybridGaussianFactorGraph hfg; + + // Add prior on x0 + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + + // Add factor between x0 and x1 + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + + // Add a gaussian mixture factor ϕ(x1, c1) + DiscreteKey m1(M(1), 2); + DecisionTree dt( + M(1), std::make_shared(X(1), I_3x3, Z_3x1), + std::make_shared(X(1), I_3x3, Vector3::Ones())); + hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt)); + + KeySet expected_continuous{X(0), X(1)}; + EXPECT( + assert_container_equality(expected_continuous, hfg.continuousKeySet())); + + KeySet expected_discrete{M(1)}; + EXPECT(assert_container_equality(expected_discrete, hfg.discreteKeySet())); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 8276264ae..1da897103 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -902,7 +902,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) { // Test resulting posterior Bayes net has correct size: EXPECT_LONGS_EQUAL(8, posterior->size()); - // TODO(dellaert): this test fails - no idea why. + // Ratio test EXPECT(ratioTest(bn, measurements, *posterior)); }