diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 784d9c95f..327b5b3d0 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -343,11 +343,20 @@ TEST(HybridBayesNet, Optimize) { } /* ****************************************************************************/ -// Test Bayes net error -TEST(HybridBayesNet, Pruning) { - // Create switching network with three continuous variables and two discrete: - // ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1) - Switching s(3); +namespace hbn_error { +// Create switching network with three continuous variables and two discrete: +// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1) +Switching s(3); + +// The true discrete assignment +const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; + +} // namespace hbn_error + +/* ****************************************************************************/ +// Test Bayes net error and log-probability +TEST(HybridBayesNet, Error) { + using namespace hbn_error; HybridBayesNet::shared_ptr posterior = s.linearizedFactorGraph().eliminateSequential(); @@ -366,7 +375,6 @@ TEST(HybridBayesNet, Pruning) { EXPECT(assert_equal(expected, discretePosterior, 1e-6)); // Verify logProbability computation and check specific logProbability value - const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; const HybridValues hybridValues{delta.continuous(), discrete_values}; double logProbability = 0; logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues); @@ -390,17 +398,32 @@ TEST(HybridBayesNet, Pruning) { // Check agreement with discrete posterior double density = exp(logProbability + negLogConstant) / normalizer; EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6); +} + +/* ****************************************************************************/ +// Test Bayes net error and log-probability after pruning +TEST(HybridBayesNet, ErrorAfterPruning) { + using namespace hbn_error; + + HybridBayesNet::shared_ptr posterior = + s.linearizedFactorGraph().eliminateSequential(); + EXPECT_LONGS_EQUAL(5, posterior->size()); + + // Optimize + HybridValues delta = posterior->optimize(); // Prune and get probabilities - auto prunedBayesNet = posterior->prune(2); - auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous()); + HybridBayesNet prunedBayesNet = posterior->prune(2); + AlgebraicDecisionTree prunedTree = + prunedBayesNet.discretePosterior(delta.continuous()); - // Regression test on pruned logProbability tree + // Regression test on pruned probability tree std::vector pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578}; AlgebraicDecisionTree expected_pruned(s.modes, pruned_leaves); EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); - // Regression + // Regression to check specific logProbability value + const HybridValues hybridValues{delta.continuous(), discrete_values}; double pruned_logProbability = 0; pruned_logProbability += prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues); @@ -423,24 +446,6 @@ TEST(HybridBayesNet, Pruning) { EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9); } -/* ****************************************************************************/ -// Test Bayes net pruning -TEST(HybridBayesNet, Prune) { - Switching s(4); - - HybridBayesNet::shared_ptr posterior = - s.linearizedFactorGraph().eliminateSequential(); - EXPECT_LONGS_EQUAL(7, posterior->size()); - - HybridValues delta = posterior->optimize(); - - auto prunedBayesNet = posterior->prune(2); - HybridValues pruned_delta = prunedBayesNet.optimize(); - - EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete())); - EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous())); -} - /* ****************************************************************************/ // Test Bayes net updateDiscreteConditionals TEST(HybridBayesNet, UpdateDiscreteConditionals) {