diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 5493ce53e..815e32560 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -167,6 +167,61 @@ TEST(HybridSmoother, ValidPruningError) { EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8); } +/****************************************************************************/ +// Test if dead mode removal works. +TEST(HybridSmoother, DeadModeRemoval) { + using namespace estimation_fixture; + + size_t K = 8; + + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + HybridNonlinearFactorGraph graph; + Values initial; + Switching switching = InitializeEstimationProblem( + K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial); + + // Smoother with dead mode removal enabled. + HybridSmoother smoother(true); + + constexpr size_t maxNrLeaves = 3; + for (size_t k = 1; k < K; k++) { + if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain + graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model + graph.push_back(switching.unaryFactors.at(k)); // Measurement + + initial.insert(X(k), switching.linearizationPoint.at(X(k))); + + HybridGaussianFactorGraph linearized = *graph.linearize(initial); + + // std::cout << "\n\n\nk" << std::endl; + // GTSAM_PRINT(linearized); + smoother.update(linearized, maxNrLeaves); + + // Clear all the factors from the graph + graph.resize(0); + } + + // Get the continuous delta update as well as + // the optimal discrete assignment. + HybridValues delta = smoother.hybridBayesNet().optimize(); + + // Check discrete assignment + DiscreteValues expected_discrete; + for (size_t k = 0; k < K - 1; k++) { + expected_discrete[M(k)] = discrete_seq[k]; + } + EXPECT(assert_equal(expected_discrete, delta.discrete())); + + // Update nonlinear solution and verify + Values result = initial.retract(delta.continuous()); + Values expected_continuous; + for (size_t k = 0; k < K; k++) { + expected_continuous.insert(X(k), measurements[k]); + } + EXPECT(assert_equal(expected_continuous, result)); +} + /* ************************************************************************* */ int main() { TestResult tr;