diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 86dcd48e4..3ddad23ff 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -166,8 +166,8 @@ TEST(HybridBayesNet, Tiny) { // prune auto pruned = bayesNet.prune(1); - CHECK(pruned.at(1)->asHybrid()); - EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents()); + CHECK(pruned.at(0)->asHybrid()); + EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); EXPECT(!pruned.equals(bayesNet)); // error @@ -437,20 +437,21 @@ TEST(HybridBayesNet, RemoveDeadNodes) { const double pruneDeadVariables = 0.99; auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); + // First conditional is still the same: P( x0 | x1 m0) + EXPECT(prunedBayesNet.at(0)->isHybrid()); + + // Check that hybrid conditional that only depend on M1 + // is now Gaussian and not Hybrid + EXPECT(prunedBayesNet.at(1)->isContinuous()); + + // Third conditional is still Hybrid: P( x1 | m0 m1) -> P( x1 | m0) + EXPECT(prunedBayesNet.at(0)->isHybrid()); + // Check that discrete joint only has M0 and not (M0, M1) // since M0 is removed - KeyVector actual_keys = prunedBayesNet.at(0)->asDiscrete()->keys(); - EXPECT(KeyVector{M(0)} == actual_keys); - - // Check that hybrid conditionals that only depend on M1 - // are now Gaussian and not Hybrid - EXPECT(prunedBayesNet.at(0)->isDiscrete()); - EXPECT(prunedBayesNet.at(1)->isDiscrete()); - EXPECT(prunedBayesNet.at(2)->isHybrid()); - // Only P(X2 | X1, M1) depends on M1, - // so it gets convert to a Gaussian P(X2 | X1) - EXPECT(prunedBayesNet.at(3)->isContinuous()); - EXPECT(prunedBayesNet.at(4)->isHybrid()); + auto joint = prunedBayesNet.at(3)->asDiscrete(); + EXPECT(joint); + EXPECT(joint->keys() == KeyVector{M(0)}); } /* ****************************************************************************/ @@ -479,13 +480,13 @@ TEST(HybridBayesNet, ErrorAfterPruning) { const HybridValues hybridValues{delta.continuous(), discrete_values}; double pruned_logProbability = 0; pruned_logProbability += - prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues); + prunedBayesNet.at(0)->asHybrid()->logProbability(hybridValues); pruned_logProbability += prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues); pruned_logProbability += prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues); pruned_logProbability += - prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues); + prunedBayesNet.at(3)->asDiscrete()->logProbability(hybridValues); double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values); @@ -548,8 +549,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { }; // Get the pruned discrete conditionals as an AlgebraicDecisionTree - CHECK(pruned.at(0)->asDiscrete()); - auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete(); + CHECK(pruned.at(4)->asDiscrete()); + auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete(); auto discrete_conditional_tree = std::dynamic_pointer_cast( pruned_discrete_conditionals); diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 3a0f376cc..97a302faf 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -103,7 +103,7 @@ TEST(HybridSmoother, IncrementalSmoother) { } EXPECT_LONGS_EQUAL(11, - smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(5)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment. @@ -157,7 +157,7 @@ TEST(HybridSmoother, ValidPruningError) { } EXPECT_LONGS_EQUAL(14, - smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(8)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment.