diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 5fff6d423..ee7af73b9 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -150,7 +150,7 @@ TEST(DiscreteBayesNet, Sugar) { } /* ************************************************************************* */ -TEST_UNSAFE(DiscreteBayesNet, Dot) { +TEST(DiscreteBayesNet, Dot) { DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), Either(5, 2); diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 73f345151..d9f5f5df7 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -26,76 +26,89 @@ using namespace boost::assign; #include +#include #include using namespace std; using namespace gtsam; - -static bool debug = false; +static constexpr bool debug = false; /* ************************************************************************* */ - -TEST_UNSAFE(DiscreteBayesTree, ThinTree) { - const int nrNodes = 15; - const size_t nrStates = 2; - - // define variables - vector key; - for (int i = 0; i < nrNodes; i++) { - DiscreteKey key_i(i, nrStates); - key.push_back(key_i); - } - - // create a thin-tree Bayesnet, a la Jean-Guillaume +struct TestFixture { + vector keys; DiscreteBayesNet bayesNet; - bayesNet.add(key[14] % "1/3"); + boost::shared_ptr bayesTree; - bayesNet.add(key[13] | key[14] = "1/3 3/1"); - bayesNet.add(key[12] | key[14] = "3/1 3/1"); + /** + * Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student), + * and then create the Bayes tree from it. + */ + TestFixture() { + // Define variables. + for (int i = 0; i < 15; i++) { + DiscreteKey key_i(i, 2); + keys.push_back(key_i); + } - bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1"); - bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); - bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); + // Create thin-tree Bayesnet. + bayesNet.add(keys[14] % "1/3"); - bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1"); - bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4"); - bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); + bayesNet.add(keys[13] | keys[14] = "1/3 3/1"); + bayesNet.add(keys[12] | keys[14] = "3/1 3/1"); - bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); - bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); - bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1"); + bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1"); + bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4"); + bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1"); + + bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1"); + bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4"); + bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1"); + + bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1"); + bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4"); + bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1"); + + // Create a BayesTree out of the Bayes net. + bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); + } +}; + +/* ************************************************************************* */ +TEST(DiscreteBayesTree, ThinTree) { + const TestFixture self; + const auto& keys = self.keys; if (debug) { - GTSAM_PRINT(bayesNet); - bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); + GTSAM_PRINT(self.bayesNet); + self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); } // create a BayesTree out of a Bayes net - auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); if (debug) { - GTSAM_PRINT(*bayesTree); - bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); + GTSAM_PRINT(*self.bayesTree); + self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); } // Check frontals and parents for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) { - auto clique_i = (*bayesTree)[i]; + auto clique_i = (*self.bayesTree)[i]; EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); } - auto R = bayesTree->roots().front(); + auto R = self.bayesTree->roots().front(); // Check whether BN and BT give the same answer on all configurations - vector allPosbValues = cartesianProduct( - key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & - key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); + vector allPosbValues = + cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & + keys[5] & keys[6] & keys[7] & keys[8] & keys[9] & + keys[10] & keys[11] & keys[12] & keys[13] & keys[14]); for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteValues x = allPosbValues[i]; - double expected = bayesNet.evaluate(x); - double actual = bayesTree->evaluate(x); + double expected = self.bayesNet.evaluate(x); + double actual = self.bayesTree->evaluate(x); DOUBLES_EQUAL(expected, actual, 1e-9); } @@ -107,7 +120,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteValues x = allPosbValues[i]; - double px = bayesTree->evaluate(x); + double px = self.bayesTree->evaluate(x); for (size_t i = 0; i < 15; i++) if (x[i]) marginals[i] += px; if (x[12] && x[14]) { @@ -141,46 +154,46 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { DiscreteValues all1 = allPosbValues.back(); // check separator marginal P(S0) - auto clique = (*bayesTree)[0]; + auto clique = (*self.bayesTree)[0]; DiscreteFactorGraph separatorMarginal0 = clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); // check separator marginal P(S9), should be P(14) - clique = (*bayesTree)[9]; + clique = (*self.bayesTree)[9]; DiscreteFactorGraph separatorMarginal9 = clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); // check separator marginal of root, should be empty - clique = (*bayesTree)[11]; + clique = (*self.bayesTree)[11]; DiscreteFactorGraph separatorMarginal11 = clique->separatorMarginal(EliminateDiscrete); LONGS_EQUAL(0, separatorMarginal11.size()); // check shortcut P(S9||R) to root - clique = (*bayesTree)[9]; + clique = (*self.bayesTree)[9]; DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); LONGS_EQUAL(1, shortcut.size()); DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S8||R) to root - clique = (*bayesTree)[8]; + clique = (*self.bayesTree)[8]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S2||R) to root - clique = (*bayesTree)[2]; + clique = (*self.bayesTree)[2]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S0||R) to root - clique = (*bayesTree)[0]; + clique = (*self.bayesTree)[0]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // calculate all shortcuts to root - DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); + DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); for (auto clique : cliques) { DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); if (debug) { @@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { // Check all marginals DiscreteFactor::shared_ptr marginalFactor; for (size_t i = 0; i < 15; i++) { - marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete); + marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete); double actual = (*marginalFactor)(all1); DOUBLES_EQUAL(marginals[i], actual, 1e-9); } @@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { DiscreteBayesNet::shared_ptr actualJoint; // Check joint P(8, 2) - actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete); DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9); // Check joint P(1, 2) - actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete); DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9); // Check joint P(2, 4) - actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete); DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 5) - actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete); DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 6) - actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete); DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 11) - actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete); DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9); } +/* ************************************************************************* */ +TEST(DiscreteBayesTree, Dot) { + const TestFixture self; + string actual = self.bayesTree->dot(); + EXPECT(actual == + "digraph G{\n" + "0[label=\"13,11,6,7\"];\n" + "0->1\n" + "1[label=\"14 : 11,13\"];\n" + "1->2\n" + "2[label=\"9,12 : 14\"];\n" + "2->3\n" + "3[label=\"3 : 9,12\"];\n" + "2->4\n" + "4[label=\"2 : 9,12\"];\n" + "2->5\n" + "5[label=\"8 : 12,14\"];\n" + "5->6\n" + "6[label=\"1 : 8,12\"];\n" + "5->7\n" + "7[label=\"0 : 8,12\"];\n" + "1->8\n" + "8[label=\"10 : 13,14\"];\n" + "8->9\n" + "9[label=\"5 : 10,13\"];\n" + "8->10\n" + "10[label=\"4 : 10,13\"];\n" + "}"); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index e430b3fe4..a1a350ac2 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -68,8 +68,7 @@ namespace gtsam { /// @{ /// Output to graphviz format, stream version. - virtual void dot(std::ostream& os, const KeyFormatter& keyFormatter = - DefaultKeyFormatter) const; + void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; /// Output to graphviz format string. std::string dot( diff --git a/gtsam/inference/BayesTree-inst.h b/gtsam/inference/BayesTree-inst.h index 5b53a5719..9b937fefb 100644 --- a/gtsam/inference/BayesTree-inst.h +++ b/gtsam/inference/BayesTree-inst.h @@ -63,20 +63,40 @@ namespace gtsam { } /* ************************************************************************* */ - template - void BayesTree::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const { - if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!"); - std::ofstream of(s.c_str()); - of<< "digraph G{\n"; - for(const sharedClique& root: roots_) - saveGraph(of, root, keyFormatter); - of<<"}"; + template + void BayesTree::dot(std::ostream& os, + const KeyFormatter& keyFormatter) const { + if (roots_.empty()) + throw std::invalid_argument( + "the root of Bayes tree has not been initialized!"); + os << "digraph G{\n"; + for (const sharedClique& root : roots_) dot(os, root, keyFormatter); + os << "}"; + std::flush(os); + } + + /* ************************************************************************* */ + template + std::string BayesTree::dot(const KeyFormatter& keyFormatter) const { + std::stringstream ss; + dot(ss, keyFormatter); + return ss.str(); + } + + /* ************************************************************************* */ + template + void BayesTree::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter); of.close(); } /* ************************************************************************* */ - template - void BayesTree::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const { + template + void BayesTree::dot(std::ostream& s, sharedClique clique, + const KeyFormatter& indexFormatter, + int parentnum) const { static int num = 0; bool first = true; std::stringstream out; @@ -107,7 +127,7 @@ namespace gtsam { for (sharedClique c : clique->children) { num++; - saveGraph(s, c, indexFormatter, parentnum); + dot(s, c, indexFormatter, parentnum); } } diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index cc003d8dc..3741e5a1c 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -182,13 +182,17 @@ namespace gtsam { */ sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const; - /** - * Read only with side effects - */ + /// Output to graphviz format, stream version. + void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - /** saves the Tree to a text file in GraphViz format */ - void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format string. + std::string dot( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// @} /// @name Advanced Interface /// @{ @@ -236,8 +240,8 @@ namespace gtsam { protected: /** private helper method for saving the Tree to a text file in GraphViz format */ - void saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, - int parentnum = 0) const; + void dot(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, + int parentnum = 0) const; /** Gather data on a single clique */ void getCliqueData(sharedClique clique, BayesTreeCliqueData* stats) const; @@ -249,7 +253,7 @@ namespace gtsam { void fillNodesIndex(const sharedClique& subtree); // Friend JunctionTree because it directly fills roots and nodes index. - template friend class EliminatableClusterTree; + template friend class EliminatableClusterTree; private: /** Serialization function */