dot methods for Bayes tree

release/4.3a0
Frank Dellaert 2021-12-19 09:37:30 -05:00
parent 352268448c
commit d85a1e68e4
5 changed files with 146 additions and 80 deletions

View File

@ -150,7 +150,7 @@ TEST(DiscreteBayesNet, Sugar) {
}
/* ************************************************************************* */
TEST_UNSAFE(DiscreteBayesNet, Dot) {
TEST(DiscreteBayesNet, Dot) {
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
Either(5, 2);

View File

@ -26,76 +26,89 @@ using namespace boost::assign;
#include <CppUnitLite/TestHarness.h>
#include <iostream>
#include <vector>
using namespace std;
using namespace gtsam;
static bool debug = false;
static constexpr bool debug = false;
/* ************************************************************************* */
struct TestFixture {
vector<DiscreteKey> keys;
DiscreteBayesNet bayesNet;
boost::shared_ptr<DiscreteBayesTree> bayesTree;
TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
const int nrNodes = 15;
const size_t nrStates = 2;
// define variables
vector<DiscreteKey> key;
for (int i = 0; i < nrNodes; i++) {
DiscreteKey key_i(i, nrStates);
key.push_back(key_i);
/**
* Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student),
* and then create the Bayes tree from it.
*/
TestFixture() {
// Define variables.
for (int i = 0; i < 15; i++) {
DiscreteKey key_i(i, 2);
keys.push_back(key_i);
}
// create a thin-tree Bayesnet, a la Jean-Guillaume
DiscreteBayesNet bayesNet;
bayesNet.add(key[14] % "1/3");
// Create thin-tree Bayesnet.
bayesNet.add(keys[14] % "1/3");
bayesNet.add(key[13] | key[14] = "1/3 3/1");
bayesNet.add(key[12] | key[14] = "3/1 3/1");
bayesNet.add(keys[13] | keys[14] = "1/3 3/1");
bayesNet.add(keys[12] | keys[14] = "3/1 3/1");
bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1");
bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1");
bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4");
bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1");
bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1");
bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1");
bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4");
bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1");
bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1");
bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1");
bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4");
bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1");
// Create a BayesTree out of the Bayes net.
bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
}
};
/* ************************************************************************* */
TEST(DiscreteBayesTree, ThinTree) {
const TestFixture self;
const auto& keys = self.keys;
if (debug) {
GTSAM_PRINT(bayesNet);
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
GTSAM_PRINT(self.bayesNet);
self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
}
// create a BayesTree out of a Bayes net
auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
if (debug) {
GTSAM_PRINT(*bayesTree);
bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
GTSAM_PRINT(*self.bayesTree);
self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
}
// Check frontals and parents
for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
auto clique_i = (*bayesTree)[i];
auto clique_i = (*self.bayesTree)[i];
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
}
auto R = bayesTree->roots().front();
auto R = self.bayesTree->roots().front();
// Check whether BN and BT give the same answer on all configurations
vector<DiscreteValues> allPosbValues = cartesianProduct(
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
vector<DiscreteValues> allPosbValues =
cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] &
keys[5] & keys[6] & keys[7] & keys[8] & keys[9] &
keys[10] & keys[11] & keys[12] & keys[13] & keys[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i];
double expected = bayesNet.evaluate(x);
double actual = bayesTree->evaluate(x);
double expected = self.bayesNet.evaluate(x);
double actual = self.bayesTree->evaluate(x);
DOUBLES_EQUAL(expected, actual, 1e-9);
}
@ -107,7 +120,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i];
double px = bayesTree->evaluate(x);
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
@ -141,46 +154,46 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
DiscreteValues all1 = allPosbValues.back();
// check separator marginal P(S0)
auto clique = (*bayesTree)[0];
auto clique = (*self.bayesTree)[0];
DiscreteFactorGraph separatorMarginal0 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// check separator marginal P(S9), should be P(14)
clique = (*bayesTree)[9];
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*bayesTree)[11];
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
// check shortcut P(S9||R) to root
clique = (*bayesTree)[9];
clique = (*self.bayesTree)[9];
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
LONGS_EQUAL(1, shortcut.size());
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S8||R) to root
clique = (*bayesTree)[8];
clique = (*self.bayesTree)[8];
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S2||R) to root
clique = (*bayesTree)[2];
clique = (*self.bayesTree)[2];
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S0||R) to root
clique = (*bayesTree)[0];
clique = (*self.bayesTree)[0];
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
for (auto clique : cliques) {
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
if (debug) {
@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
// Check all marginals
DiscreteFactor::shared_ptr marginalFactor;
for (size_t i = 0; i < 15; i++) {
marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1);
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
}
@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
DiscreteBayesNet::shared_ptr actualJoint;
// Check joint P(8, 2)
actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
// Check joint P(1, 2)
actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
// Check joint P(2, 4)
actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 5)
actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 6)
actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 11)
actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
}
/* ************************************************************************* */
TEST(DiscreteBayesTree, Dot) {
const TestFixture self;
string actual = self.bayesTree->dot();
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13,11,6,7\"];\n"
"0->1\n"
"1[label=\"14 : 11,13\"];\n"
"1->2\n"
"2[label=\"9,12 : 14\"];\n"
"2->3\n"
"3[label=\"3 : 9,12\"];\n"
"2->4\n"
"4[label=\"2 : 9,12\"];\n"
"2->5\n"
"5[label=\"8 : 12,14\"];\n"
"5->6\n"
"6[label=\"1 : 8,12\"];\n"
"5->7\n"
"7[label=\"0 : 8,12\"];\n"
"1->8\n"
"8[label=\"10 : 13,14\"];\n"
"8->9\n"
"9[label=\"5 : 10,13\"];\n"
"8->10\n"
"10[label=\"4 : 10,13\"];\n"
"}");
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -68,8 +68,7 @@ namespace gtsam {
/// @{
/// Output to graphviz format, stream version.
virtual void dot(std::ostream& os, const KeyFormatter& keyFormatter =
DefaultKeyFormatter) const;
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(

View File

@ -64,19 +64,39 @@ namespace gtsam {
/* ************************************************************************* */
template <class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const {
if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!");
std::ofstream of(s.c_str());
of<< "digraph G{\n";
for(const sharedClique& root: roots_)
saveGraph(of, root, keyFormatter);
of<<"}";
void BayesTree<CLIQUE>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const {
if (roots_.empty())
throw std::invalid_argument(
"the root of Bayes tree has not been initialized!");
os << "digraph G{\n";
for (const sharedClique& root : roots_) dot(os, root, keyFormatter);
os << "}";
std::flush(os);
}
/* ************************************************************************* */
template <class CLIQUE>
std::string BayesTree<CLIQUE>::dot(const KeyFormatter& keyFormatter) const {
std::stringstream ss;
dot(ss, keyFormatter);
return ss.str();
}
/* ************************************************************************* */
template <class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const {
std::ofstream of(filename.c_str());
dot(of, keyFormatter);
of.close();
}
/* ************************************************************************* */
template <class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const {
void BayesTree<CLIQUE>::dot(std::ostream& s, sharedClique clique,
const KeyFormatter& indexFormatter,
int parentnum) const {
static int num = 0;
bool first = true;
std::stringstream out;
@ -107,7 +127,7 @@ namespace gtsam {
for (sharedClique c : clique->children) {
num++;
saveGraph(s, c, indexFormatter, parentnum);
dot(s, c, indexFormatter, parentnum);
}
}

View File

@ -182,12 +182,16 @@ namespace gtsam {
*/
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;
/**
* Read only with side effects
*/
/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/** saves the Tree to a text file in GraphViz format */
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
/// @name Advanced Interface
@ -236,7 +240,7 @@ namespace gtsam {
protected:
/** private helper method for saving the Tree to a text file in GraphViz format */
void saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter,
void dot(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter,
int parentnum = 0) const;
/** Gather data on a single clique */
@ -249,7 +253,7 @@ namespace gtsam {
void fillNodesIndex(const sharedClique& subtree);
// Friend JunctionTree because it directly fills roots and nodes index.
template<class BAYESRTEE, class GRAPH> friend class EliminatableClusterTree;
template<class BAYESTREE, class GRAPH> friend class EliminatableClusterTree;
private:
/** Serialization function */