diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 95d6f0370..bc205c96c 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -28,7 +28,7 @@ #include using namespace gtsam; -static constexpr bool debug = false; +static constexpr bool debug = true; /* ************************************************************************* */ struct TestFixture { @@ -186,11 +186,11 @@ TEST(DiscreteBayesTree, Shortcuts) { shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); - // calculate all shortcuts to root - DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); - for (auto clique : cliques) { - DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); - if (debug) { + if (debug) { + // print all shortcuts to root + DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); + for (auto clique : cliques) { + DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); clique.second->conditional_->printSignature(); shortcut.print("shortcut:"); } @@ -202,6 +202,7 @@ TEST(DiscreteBayesTree, Shortcuts) { TEST(DiscreteBayesTree, MarginalFactors) { TestFixture self; + // Caclulate marginals with brute force enumeration. Vector marginals = Vector::Zero(15); for (size_t i = 0; i < self.assignments.size(); ++i) { DiscreteValues& x = self.assignments[i]; @@ -287,6 +288,8 @@ TEST(DiscreteBayesTree, Joints) { TEST(DiscreteBayesTree, Dot) { TestFixture self; std::string actual = self.bayesTree->dot(); + // print actual: + if (debug) std::cout << actual << std::endl; EXPECT(actual == "digraph G{\n" "0[label=\"13, 11, 6, 7\"];\n" @@ -374,18 +377,18 @@ TEST(DiscreteBayesTree, Lookup) { TEST(DiscreteBayesTree, DirectFromCliques) { // Create a BayesNet DiscreteBayesNet bayesNet; - DiscreteKey key0(0, 2), key1(1, 2), key2(2, 2); - bayesNet.add(key0 % "1/3"); - bayesNet.add(key1 | key0 = "1/3 3/1"); - bayesNet.add(key2 | key1 = "3/1 3/1"); + DiscreteKey A(0, 2), B(1, 2), C(2, 2); + bayesNet.add(A % "1/3"); + bayesNet.add(B | A = "1/3 3/1"); + bayesNet.add(C | B = "3/1 3/1"); // Create cliques directly auto clique2 = std::make_shared( - std::make_shared(key2 | key1 = "3/1 3/1")); + std::make_shared(C | B = "3/1 3/1")); auto clique1 = std::make_shared( - std::make_shared(key1 | key0 = "1/3 3/1")); + std::make_shared(B | A = "1/3 3/1")); auto clique0 = std::make_shared( - std::make_shared(key0 % "1/3")); + std::make_shared(A % "1/3")); // Create a BayesTree DiscreteBayesTree bayesTree; @@ -395,13 +398,13 @@ TEST(DiscreteBayesTree, DirectFromCliques) { // Check that the BayesTree is correct DiscreteValues values; - values[0] = 1; - values[1] = 1; - values[2] = 1; + values[A.first] = 1; + values[A.first] = 1; + values[A.first] = 1; - double expected = bayesNet.evaluate(values); - double actual = bayesTree.evaluate(values); - DOUBLES_EQUAL(expected, actual, 1e-9); + // Regression + double expected = .046875; + DOUBLES_EQUAL(expected, bayesTree.evaluate(values), 1e-9); } /* ************************************************************************* */ diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index d327b5a1b..e8943fc80 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -160,15 +160,15 @@ class TestDiscreteBayesNet(GtsamTestCase): """Test creating a Bayes tree directly from cliques.""" # Create a BayesNet bayesNet = DiscreteBayesNet() - key0, key1, key2 = (0, 2), (1, 2), (2, 2) - bayesNet.add(key0, "1/3") - bayesNet.add(key1, [key0], "1/3 3/1") - bayesNet.add(key2, [key1], "3/1 3/1") + A, B, C = (0, 2), (1, 2), (2, 2) + bayesNet.add(A, "1/3") + bayesNet.add(B, [A], "1/3 3/1") + bayesNet.add(C, [B], "3/1 3/1") # Create cliques directly - clique2 = DiscreteBayesTreeClique(DiscreteConditional(key2, [key1], "3/1 3/1")) - clique1 = DiscreteBayesTreeClique(DiscreteConditional(key1, [key0], "1/3 3/1")) - clique0 = DiscreteBayesTreeClique(DiscreteConditional(key0, "1/3")) + clique2 = DiscreteBayesTreeClique(DiscreteConditional(C, [B], "3/1 3/1")) + clique1 = DiscreteBayesTreeClique(DiscreteConditional(B, [A], "1/3 3/1")) + clique0 = DiscreteBayesTreeClique(DiscreteConditional(A, "1/3")) # Create a BayesTree bayesTree = gtsam.DiscreteBayesTree() @@ -182,9 +182,9 @@ class TestDiscreteBayesNet(GtsamTestCase): values[1] = 1 values[2] = 1 - expected = bayesNet.evaluate(values) - actual = bayesTree.evaluate(values) - self.assertAlmostEqual(expected, actual, places=9) + # regression + expected = .046875 + self.assertAlmostEqual(expected, bayesNet.evaluate(values)) if __name__ == "__main__":