From 3d55fe0d378dff654d8520e9d174b93be22948d5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 22:56:32 -0700 Subject: [PATCH] Finish tests --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 36 ++++++++++++++++++----- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 8988d1e62..ee47a698a 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -65,13 +65,18 @@ TEST(HybridBayesNet, Add) { // Test API for a pure discrete Bayes net P(Asia). TEST(HybridBayesNet, EvaluatePureDiscrete) { HybridBayesNet bayesNet; - bayesNet.emplace_shared(Asia, "4/6"); + const auto pAsia = std::make_shared(Asia, "4/6"); + bayesNet.push_back(pAsia); HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}}; // choose GaussianBayesNet empty; EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9)); + // logProbability + EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); + // evaluate EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9); EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9); @@ -88,18 +93,35 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) { // prune EXPECT(assert_equal(bayesNet, bayesNet.prune(2))); - // EXPECT(assert_equal(bayesNet, bayesNet.prune(1))); Should fail !!! - // EXPECT(assert_equal(bayesNet, bayesNet.prune(0))); Should fail !!! + EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size()); // errorTree AlgebraicDecisionTree actual = bayesNet.errorTree({}); - AlgebraicDecisionTree expected( + AlgebraicDecisionTree expectedErrorTree( {Asia}, std::vector{-log(0.4), -log(0.6)}); - EXPECT(assert_equal(expected, actual)); + EXPECT(assert_equal(expectedErrorTree, actual)); // error EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); + + // logDiscretePosteriorPrime, TODO: useless as -errorTree? + AlgebraicDecisionTree expected({Asia}, + std::vector{log(0.4), log(0.6)}); + EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime({}))); + + // logProbability + EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); + EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); + + // discretePosterior + AlgebraicDecisionTree expectedPosterior({Asia}, + std::vector{0.4, 0.6}); + EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({}))); + + // toFactorGraph + HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({}); + EXPECT(assert_equal(expectedFG, fg)); } /* ****************************************************************************/ @@ -358,7 +380,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials); // Prune! - posterior->prune(maxNrLeaves); + auto pruned = posterior->prune(maxNrLeaves); // Functor to verify values against the expected_discrete_conditionals auto checker = [&](const Assignment& assignment, @@ -375,7 +397,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { }; // Get the pruned discrete conditionals as an AlgebraicDecisionTree - auto pruned_discrete_conditionals = posterior->at(4)->asDiscrete(); + auto pruned_discrete_conditionals = pruned.at(4)->asDiscrete(); auto discrete_conditional_tree = std::dynamic_pointer_cast( pruned_discrete_conditionals);