From b2fbbbb26fc994b04b8cc5bcf692cdfe2dedee81 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 22 Jan 2025 12:05:27 -0500 Subject: [PATCH] New DiscreteBayesTree tests --- .../discrete/tests/testDiscreteBayesTree.cpp | 35 +++++++++++++++++++ python/gtsam/tests/test_DiscreteBayesTree.py | 31 ++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 617eb7c9d..95d6f0370 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -369,6 +369,41 @@ TEST(DiscreteBayesTree, Lookup) { EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9); } +/* ************************************************************************* */ +// Test creating a Bayes tree directly from cliques +TEST(DiscreteBayesTree, DirectFromCliques) { + // Create a BayesNet + DiscreteBayesNet bayesNet; + DiscreteKey key0(0, 2), key1(1, 2), key2(2, 2); + bayesNet.add(key0 % "1/3"); + bayesNet.add(key1 | key0 = "1/3 3/1"); + bayesNet.add(key2 | key1 = "3/1 3/1"); + + // Create cliques directly + auto clique2 = std::make_shared( + std::make_shared(key2 | key1 = "3/1 3/1")); + auto clique1 = std::make_shared( + std::make_shared(key1 | key0 = "1/3 3/1")); + auto clique0 = std::make_shared( + std::make_shared(key0 % "1/3")); + + // Create a BayesTree + DiscreteBayesTree bayesTree; + bayesTree.insertRoot(clique2); + bayesTree.addClique(clique1, clique2); + bayesTree.addClique(clique0, clique1); + + // Check that the BayesTree is correct + DiscreteValues values; + values[0] = 1; + values[1] = 1; + values[2] = 1; + + double expected = bayesNet.evaluate(values); + double actual = bayesTree.evaluate(values); + DOUBLES_EQUAL(expected, actual, 1e-9); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index e08491fab..d327b5a1b 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -156,5 +156,36 @@ class TestDiscreteBayesNet(GtsamTestCase): values[X(3)] = 2 self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10... + def test_direct_from_cliques(self): + """Test creating a Bayes tree directly from cliques.""" + # Create a BayesNet + bayesNet = DiscreteBayesNet() + key0, key1, key2 = (0, 2), (1, 2), (2, 2) + bayesNet.add(key0, "1/3") + bayesNet.add(key1, [key0], "1/3 3/1") + bayesNet.add(key2, [key1], "3/1 3/1") + + # Create cliques directly + clique2 = DiscreteBayesTreeClique(DiscreteConditional(key2, [key1], "3/1 3/1")) + clique1 = DiscreteBayesTreeClique(DiscreteConditional(key1, [key0], "1/3 3/1")) + clique0 = DiscreteBayesTreeClique(DiscreteConditional(key0, "1/3")) + + # Create a BayesTree + bayesTree = gtsam.DiscreteBayesTree() + bayesTree.insertRoot(clique2) + bayesTree.addClique(clique1, clique2) + bayesTree.addClique(clique0, clique1) + + # Check that the BayesTree is correct + values = DiscreteValues() + values[0] = 1 + values[1] = 1 + values[2] = 1 + + expected = bayesNet.evaluate(values) + actual = bayesTree.evaluate(values) + self.assertAlmostEqual(expected, actual, places=9) + + if __name__ == "__main__": unittest.main()