fix continuousKeySet method

release/4.3a0
Varun Agrawal 2023-07-06 22:22:51 -04:00
parent f751a5bfcf
commit 4e902fc8a7
3 changed files with 31 additions and 1 deletions

View File

@ -67,6 +67,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const {
for (const Key& key : p->continuousKeys()) { for (const Key& key : p->continuousKeys()) {
keys.insert(key); keys.insert(key);
} }
} else if (auto p = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
keys.insert(p->keys().begin(), p->keys().end());
} }
} }
return keys; return keys;

View File

@ -18,7 +18,9 @@
#include <gtsam/base/TestableAssertions.h> #include <gtsam/base/TestableAssertions.h>
#include <gtsam/base/utilities.h> #include <gtsam/base/utilities.h>
#include <gtsam/hybrid/HybridFactorGraph.h> #include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Symbol.h> #include <gtsam/inference/Symbol.h>
#include <gtsam/linear/JacobianFactor.h>
#include <gtsam/nonlinear/PriorFactor.h> #include <gtsam/nonlinear/PriorFactor.h>
using namespace std; using namespace std;
@ -37,6 +39,32 @@ TEST(HybridFactorGraph, Constructor) {
HybridFactorGraph fg; HybridFactorGraph fg;
} }
/* ************************************************************************* */
// Test if methods to get keys work as expected.
TEST(HybridFactorGraph, Keys) {
HybridGaussianFactorGraph hfg;
// Add prior on x0
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
// Add factor between x0 and x1
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
// Add a gaussian mixture factor ϕ(x1, c1)
DiscreteKey m1(M(1), 2);
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
M(1), std::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
std::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));
KeySet expected_continuous{X(0), X(1)};
EXPECT(
assert_container_equality(expected_continuous, hfg.continuousKeySet()));
KeySet expected_discrete{M(1)};
EXPECT(assert_container_equality(expected_discrete, hfg.discreteKeySet()));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -902,7 +902,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
// Test resulting posterior Bayes net has correct size: // Test resulting posterior Bayes net has correct size:
EXPECT_LONGS_EQUAL(8, posterior->size()); EXPECT_LONGS_EQUAL(8, posterior->size());
// TODO(dellaert): this test fails - no idea why. // Ratio test
EXPECT(ratioTest(bn, measurements, *posterior)); EXPECT(ratioTest(bn, measurements, *posterior));
} }