diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 0a7dc72f4..78e254aed 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -16,10 +16,10 @@ */ #include +#include #include #include #include -#include #include @@ -32,7 +32,8 @@ static constexpr bool debug = false; /* ************************************************************************* */ struct TestFixture { - vector keys; + DiscreteKeys keys; + std::vector assignments; DiscreteBayesNet bayesNet; boost::shared_ptr bayesTree; @@ -47,6 +48,9 @@ struct TestFixture { keys.push_back(key_i); } + // Enumerate all assignments. + assignments = DiscreteValues::CartesianProduct(keys); + // Create thin-tree Bayesnet. bayesNet.add(keys[14] % "1/3"); @@ -74,9 +78,9 @@ struct TestFixture { }; /* ************************************************************************* */ +// Check that BN and BT give the same answer on all configurations TEST(DiscreteBayesTree, ThinTree) { - const TestFixture self; - const auto& keys = self.keys; + TestFixture self; if (debug) { GTSAM_PRINT(self.bayesNet); @@ -95,47 +99,56 @@ TEST(DiscreteBayesTree, ThinTree) { EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); } - auto R = self.bayesTree->roots().front(); - - // Check whether BN and BT give the same answer on all configurations - auto allPosbValues = DiscreteValues::CartesianProduct( - keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] & - keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] & - keys[14]); - for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteValues x = allPosbValues[i]; + for (const auto& x : self.assignments) { double expected = self.bayesNet.evaluate(x); double actual = self.bayesTree->evaluate(x); DOUBLES_EQUAL(expected, actual, 1e-9); } +} - // Calculate all some marginals for DiscreteValues==all1 - Vector marginals = Vector::Zero(15); - double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, - joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, - joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0, - joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; - for (size_t i = 0; i < allPosbValues.size(); ++i) { - DiscreteValues x = allPosbValues[i]; +/* ************************************************************************* */ +// Check calculation of separator marginals +TEST(DiscreteBayesTree, separatorMarginal) { + TestFixture self; + + // Calculate some marginals for DiscreteValues==all1 + double marginal_14 = 0, joint_8_12 = 0; + for (auto& x : self.assignments) { double px = self.bayesTree->evaluate(x); - for (size_t i = 0; i < 15; i++) - if (x[i]) marginals[i] += px; - if (x[12] && x[14]) { - joint_12_14 += px; - if (x[9]) joint_9_12_14 += px; - if (x[8]) joint_8_12_14 += px; - } if (x[8] && x[12]) joint_8_12 += px; - if (x[2]) { - if (x[8]) joint82 += px; - if (x[1]) joint12 += px; - } - if (x[4]) { - if (x[2]) joint24 += px; - if (x[5]) joint45 += px; - if (x[6]) joint46 += px; - if (x[11]) joint_4_11 += px; - } + if (x[14]) marginal_14 += px; + } + DiscreteValues all1 = self.assignments.back(); + + // check separator marginal P(S0) + auto clique = (*self.bayesTree)[0]; + DiscreteFactorGraph separatorMarginal0 = + clique->separatorMarginal(EliminateDiscrete); + DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); + + // check separator marginal P(S9), should be P(14) + clique = (*self.bayesTree)[9]; + DiscreteFactorGraph separatorMarginal9 = + clique->separatorMarginal(EliminateDiscrete); + DOUBLES_EQUAL(marginal_14, separatorMarginal9(all1), 1e-9); + + // check separator marginal of root, should be empty + clique = (*self.bayesTree)[11]; + DiscreteFactorGraph separatorMarginal11 = + clique->separatorMarginal(EliminateDiscrete); + LONGS_EQUAL(0, separatorMarginal11.size()); +} + +/* ************************************************************************* */ +// Check shortcuts in the tree +TEST(DiscreteBayesTree, shortcut) { + TestFixture self; + + // Calculate some marginals for DiscreteValues==all1 + double joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0, + joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; + for (auto& x : self.assignments) { + double px = self.bayesTree->evaluate(x); if (x[11] && x[13]) { joint_11_13 += px; if (x[8] && x[12]) joint_8_11_12_13 += px; @@ -148,32 +161,12 @@ TEST(DiscreteBayesTree, ThinTree) { } } } - DiscreteValues all1 = allPosbValues.back(); + DiscreteValues all1 = self.assignments.back(); - // check separator marginal P(S0) - auto clique = (*self.bayesTree)[0]; - DiscreteFactorGraph separatorMarginal0 = - clique->separatorMarginal(EliminateDiscrete); - DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); - - DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9); - DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9); - DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9); - - // check separator marginal P(S9), should be P(14) - clique = (*self.bayesTree)[9]; - DiscreteFactorGraph separatorMarginal9 = - clique->separatorMarginal(EliminateDiscrete); - DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); - - // check separator marginal of root, should be empty - clique = (*self.bayesTree)[11]; - DiscreteFactorGraph separatorMarginal11 = - clique->separatorMarginal(EliminateDiscrete); - LONGS_EQUAL(0, separatorMarginal11.size()); + auto R = self.bayesTree->roots().front(); // check shortcut P(S9||R) to root - clique = (*self.bayesTree)[9]; + auto clique = (*self.bayesTree)[9]; DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); LONGS_EQUAL(1, shortcut.size()); DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); @@ -202,15 +195,67 @@ TEST(DiscreteBayesTree, ThinTree) { shortcut.print("shortcut:"); } } +} + +/* ************************************************************************* */ +// Check all marginals +TEST(DiscreteBayesTree, marginalFactor) { + TestFixture self; + + Vector marginals = Vector::Zero(15); + for (size_t i = 0; i < self.assignments.size(); ++i) { + DiscreteValues& x = self.assignments[i]; + double px = self.bayesTree->evaluate(x); + for (size_t i = 0; i < 15; i++) + if (x[i]) marginals[i] += px; + } // Check all marginals - DiscreteFactor::shared_ptr marginalFactor; + DiscreteValues all1 = self.assignments.back(); for (size_t i = 0; i < 15; i++) { - marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete); + auto marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete); double actual = (*marginalFactor)(all1); DOUBLES_EQUAL(marginals[i], actual, 1e-9); } +} +/* ************************************************************************* */ +// Check a number of joint marginals. +TEST(DiscreteBayesTree, Joints) { + TestFixture self; + + // Calculate some marginals for DiscreteValues==all1 + Vector marginals = Vector::Zero(15); + double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0, + joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0; + for (size_t i = 0; i < self.assignments.size(); ++i) { + DiscreteValues& x = self.assignments[i]; + double px = self.bayesTree->evaluate(x); + for (size_t i = 0; i < 15; i++) + if (x[i]) marginals[i] += px; + if (x[12] && x[14]) { + joint_12_14 += px; + if (x[9]) joint_9_12_14 += px; + if (x[8]) joint_8_12_14 += px; + } + if (x[2]) { + if (x[8]) joint82 += px; + if (x[1]) joint12 += px; + } + if (x[4]) { + if (x[2]) joint24 += px; + if (x[5]) joint45 += px; + if (x[6]) joint46 += px; + if (x[11]) joint_4_11 += px; + } + } + + // regression tests: + DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9); + DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9); + DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9); + + DiscreteValues all1 = self.assignments.back(); DiscreteBayesNet::shared_ptr actualJoint; // Check joint P(8, 2) @@ -240,7 +285,7 @@ TEST(DiscreteBayesTree, ThinTree) { /* ************************************************************************* */ TEST(DiscreteBayesTree, Dot) { - const TestFixture self; + TestFixture self; string actual = self.bayesTree->dot(); EXPECT(actual == "digraph G{\n"