dot methods for Bayes tree

release/4.3a0
Frank Dellaert 2021-12-19 09:37:30 -05:00
parent 352268448c
commit d85a1e68e4
5 changed files with 146 additions and 80 deletions

View File

@ -150,7 +150,7 @@ TEST(DiscreteBayesNet, Sugar) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(DiscreteBayesNet, Dot) { TEST(DiscreteBayesNet, Dot) {
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
Either(5, 2); Either(5, 2);

View File

@ -26,76 +26,89 @@ using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <iostream>
#include <vector> #include <vector>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
static constexpr bool debug = false;
static bool debug = false;
/* ************************************************************************* */ /* ************************************************************************* */
struct TestFixture {
vector<DiscreteKey> keys;
DiscreteBayesNet bayesNet;
boost::shared_ptr<DiscreteBayesTree> bayesTree;
TEST_UNSAFE(DiscreteBayesTree, ThinTree) { /**
const int nrNodes = 15; * Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student),
const size_t nrStates = 2; * and then create the Bayes tree from it.
*/
// define variables TestFixture() {
vector<DiscreteKey> key; // Define variables.
for (int i = 0; i < nrNodes; i++) { for (int i = 0; i < 15; i++) {
DiscreteKey key_i(i, nrStates); DiscreteKey key_i(i, 2);
key.push_back(key_i); keys.push_back(key_i);
} }
// create a thin-tree Bayesnet, a la Jean-Guillaume // Create thin-tree Bayesnet.
DiscreteBayesNet bayesNet; bayesNet.add(keys[14] % "1/3");
bayesNet.add(key[14] % "1/3");
bayesNet.add(key[13] | key[14] = "1/3 3/1"); bayesNet.add(keys[13] | keys[14] = "1/3 3/1");
bayesNet.add(key[12] | key[14] = "3/1 3/1"); bayesNet.add(keys[12] | keys[14] = "3/1 3/1");
bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1"); bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1");
bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4");
bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1");
bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1"); bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1");
bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4"); bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4");
bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1");
bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1");
bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4");
bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1"); bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1");
// Create a BayesTree out of the Bayes net.
bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
}
};
/* ************************************************************************* */
TEST(DiscreteBayesTree, ThinTree) {
const TestFixture self;
const auto& keys = self.keys;
if (debug) { if (debug) {
GTSAM_PRINT(bayesNet); GTSAM_PRINT(self.bayesNet);
bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
} }
// create a BayesTree out of a Bayes net // create a BayesTree out of a Bayes net
auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
if (debug) { if (debug) {
GTSAM_PRINT(*bayesTree); GTSAM_PRINT(*self.bayesTree);
bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
} }
// Check frontals and parents // Check frontals and parents
for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) { for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
auto clique_i = (*bayesTree)[i]; auto clique_i = (*self.bayesTree)[i];
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
} }
auto R = bayesTree->roots().front(); auto R = self.bayesTree->roots().front();
// Check whether BN and BT give the same answer on all configurations // Check whether BN and BT give the same answer on all configurations
vector<DiscreteValues> allPosbValues = cartesianProduct( vector<DiscreteValues> allPosbValues =
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] &
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); keys[5] & keys[6] & keys[7] & keys[8] & keys[9] &
keys[10] & keys[11] & keys[12] & keys[13] & keys[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) { for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i]; DiscreteValues x = allPosbValues[i];
double expected = bayesNet.evaluate(x); double expected = self.bayesNet.evaluate(x);
double actual = bayesTree->evaluate(x); double actual = self.bayesTree->evaluate(x);
DOUBLES_EQUAL(expected, actual, 1e-9); DOUBLES_EQUAL(expected, actual, 1e-9);
} }
@ -107,7 +120,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (size_t i = 0; i < allPosbValues.size(); ++i) { for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i]; DiscreteValues x = allPosbValues[i];
double px = bayesTree->evaluate(x); double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++) for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px; if (x[i]) marginals[i] += px;
if (x[12] && x[14]) { if (x[12] && x[14]) {
@ -141,46 +154,46 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
DiscreteValues all1 = allPosbValues.back(); DiscreteValues all1 = allPosbValues.back();
// check separator marginal P(S0) // check separator marginal P(S0)
auto clique = (*bayesTree)[0]; auto clique = (*self.bayesTree)[0];
DiscreteFactorGraph separatorMarginal0 = DiscreteFactorGraph separatorMarginal0 =
clique->separatorMarginal(EliminateDiscrete); clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// check separator marginal P(S9), should be P(14) // check separator marginal P(S9), should be P(14)
clique = (*bayesTree)[9]; clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 = DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete); clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty // check separator marginal of root, should be empty
clique = (*bayesTree)[11]; clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 = DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete); clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size()); LONGS_EQUAL(0, separatorMarginal11.size());
// check shortcut P(S9||R) to root // check shortcut P(S9||R) to root
clique = (*bayesTree)[9]; clique = (*self.bayesTree)[9];
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
LONGS_EQUAL(1, shortcut.size()); LONGS_EQUAL(1, shortcut.size());
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S8||R) to root // check shortcut P(S8||R) to root
clique = (*bayesTree)[8]; clique = (*self.bayesTree)[8];
shortcut = clique->shortcut(R, EliminateDiscrete); shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S2||R) to root // check shortcut P(S2||R) to root
clique = (*bayesTree)[2]; clique = (*self.bayesTree)[2];
shortcut = clique->shortcut(R, EliminateDiscrete); shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S0||R) to root // check shortcut P(S0||R) to root
clique = (*bayesTree)[0]; clique = (*self.bayesTree)[0];
shortcut = clique->shortcut(R, EliminateDiscrete); shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// calculate all shortcuts to root // calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
for (auto clique : cliques) { for (auto clique : cliques) {
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
if (debug) { if (debug) {
@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
// Check all marginals // Check all marginals
DiscreteFactor::shared_ptr marginalFactor; DiscreteFactor::shared_ptr marginalFactor;
for (size_t i = 0; i < 15; i++) { for (size_t i = 0; i < 15; i++) {
marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete); marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1); double actual = (*marginalFactor)(all1);
DOUBLES_EQUAL(marginals[i], actual, 1e-9); DOUBLES_EQUAL(marginals[i], actual, 1e-9);
} }
@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
DiscreteBayesNet::shared_ptr actualJoint; DiscreteBayesNet::shared_ptr actualJoint;
// Check joint P(8, 2) // Check joint P(8, 2)
actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete); actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9); DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
// Check joint P(1, 2) // Check joint P(1, 2)
actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete); actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9); DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
// Check joint P(2, 4) // Check joint P(2, 4)
actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete); actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9); DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 5) // Check joint P(4, 5)
actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete); actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9); DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 6) // Check joint P(4, 6)
actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete); actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9); DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 11) // Check joint P(4, 11)
actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete); actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9); DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
} }
/* ************************************************************************* */
TEST(DiscreteBayesTree, Dot) {
const TestFixture self;
string actual = self.bayesTree->dot();
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13,11,6,7\"];\n"
"0->1\n"
"1[label=\"14 : 11,13\"];\n"
"1->2\n"
"2[label=\"9,12 : 14\"];\n"
"2->3\n"
"3[label=\"3 : 9,12\"];\n"
"2->4\n"
"4[label=\"2 : 9,12\"];\n"
"2->5\n"
"5[label=\"8 : 12,14\"];\n"
"5->6\n"
"6[label=\"1 : 8,12\"];\n"
"5->7\n"
"7[label=\"0 : 8,12\"];\n"
"1->8\n"
"8[label=\"10 : 13,14\"];\n"
"8->9\n"
"9[label=\"5 : 10,13\"];\n"
"8->10\n"
"10[label=\"4 : 10,13\"];\n"
"}");
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -68,8 +68,7 @@ namespace gtsam {
/// @{ /// @{
/// Output to graphviz format, stream version. /// Output to graphviz format, stream version.
virtual void dot(std::ostream& os, const KeyFormatter& keyFormatter = void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
DefaultKeyFormatter) const;
/// Output to graphviz format string. /// Output to graphviz format string.
std::string dot( std::string dot(

View File

@ -64,19 +64,39 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template <class CLIQUE> template <class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const { void BayesTree<CLIQUE>::dot(std::ostream& os,
if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!"); const KeyFormatter& keyFormatter) const {
std::ofstream of(s.c_str()); if (roots_.empty())
of<< "digraph G{\n"; throw std::invalid_argument(
for(const sharedClique& root: roots_) "the root of Bayes tree has not been initialized!");
saveGraph(of, root, keyFormatter); os << "digraph G{\n";
of<<"}"; for (const sharedClique& root : roots_) dot(os, root, keyFormatter);
os << "}";
std::flush(os);
}
/* ************************************************************************* */
template <class CLIQUE>
std::string BayesTree<CLIQUE>::dot(const KeyFormatter& keyFormatter) const {
std::stringstream ss;
dot(ss, keyFormatter);
return ss.str();
}
/* ************************************************************************* */
template <class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const {
std::ofstream of(filename.c_str());
dot(of, keyFormatter);
of.close(); of.close();
} }
/* ************************************************************************* */ /* ************************************************************************* */
template <class CLIQUE> template <class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const { void BayesTree<CLIQUE>::dot(std::ostream& s, sharedClique clique,
const KeyFormatter& indexFormatter,
int parentnum) const {
static int num = 0; static int num = 0;
bool first = true; bool first = true;
std::stringstream out; std::stringstream out;
@ -107,7 +127,7 @@ namespace gtsam {
for (sharedClique c : clique->children) { for (sharedClique c : clique->children) {
num++; num++;
saveGraph(s, c, indexFormatter, parentnum); dot(s, c, indexFormatter, parentnum);
} }
} }

View File

@ -182,12 +182,16 @@ namespace gtsam {
*/ */
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const; sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;
/** /// Output to graphviz format, stream version.
* Read only with side effects void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
*/
/** saves the Tree to a text file in GraphViz format */ /// Output to graphviz format string.
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
@ -236,7 +240,7 @@ namespace gtsam {
protected: protected:
/** private helper method for saving the Tree to a text file in GraphViz format */ /** private helper method for saving the Tree to a text file in GraphViz format */
void saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, void dot(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter,
int parentnum = 0) const; int parentnum = 0) const;
/** Gather data on a single clique */ /** Gather data on a single clique */
@ -249,7 +253,7 @@ namespace gtsam {
void fillNodesIndex(const sharedClique& subtree); void fillNodesIndex(const sharedClique& subtree);
// Friend JunctionTree because it directly fills roots and nodes index. // Friend JunctionTree because it directly fills roots and nodes index.
template<class BAYESRTEE, class GRAPH> friend class EliminatableClusterTree; template<class BAYESTREE, class GRAPH> friend class EliminatableClusterTree;
private: private:
/** Serialization function */ /** Serialization function */