dot methods for Bayes tree
parent
352268448c
commit
d85a1e68e4
|
@ -150,7 +150,7 @@ TEST(DiscreteBayesNet, Sugar) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(DiscreteBayesNet, Dot) {
|
||||
TEST(DiscreteBayesNet, Dot) {
|
||||
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
|
||||
Either(5, 2);
|
||||
|
||||
|
|
|
@ -26,76 +26,89 @@ using namespace boost::assign;
|
|||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
static bool debug = false;
|
||||
static constexpr bool debug = false;
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||
const int nrNodes = 15;
|
||||
const size_t nrStates = 2;
|
||||
|
||||
// define variables
|
||||
vector<DiscreteKey> key;
|
||||
for (int i = 0; i < nrNodes; i++) {
|
||||
DiscreteKey key_i(i, nrStates);
|
||||
key.push_back(key_i);
|
||||
}
|
||||
|
||||
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
||||
struct TestFixture {
|
||||
vector<DiscreteKey> keys;
|
||||
DiscreteBayesNet bayesNet;
|
||||
bayesNet.add(key[14] % "1/3");
|
||||
boost::shared_ptr<DiscreteBayesTree> bayesTree;
|
||||
|
||||
bayesNet.add(key[13] | key[14] = "1/3 3/1");
|
||||
bayesNet.add(key[12] | key[14] = "3/1 3/1");
|
||||
/**
|
||||
* Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student),
|
||||
* and then create the Bayes tree from it.
|
||||
*/
|
||||
TestFixture() {
|
||||
// Define variables.
|
||||
for (int i = 0; i < 15; i++) {
|
||||
DiscreteKey key_i(i, 2);
|
||||
keys.push_back(key_i);
|
||||
}
|
||||
|
||||
bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
|
||||
bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
|
||||
// Create thin-tree Bayesnet.
|
||||
bayesNet.add(keys[14] % "1/3");
|
||||
|
||||
bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
|
||||
bayesNet.add(keys[13] | keys[14] = "1/3 3/1");
|
||||
bayesNet.add(keys[12] | keys[14] = "3/1 3/1");
|
||||
|
||||
bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
|
||||
bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
|
||||
bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4");
|
||||
bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1");
|
||||
|
||||
bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1");
|
||||
|
||||
bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1");
|
||||
bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1");
|
||||
|
||||
// Create a BayesTree out of the Bayes net.
|
||||
bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
|
||||
}
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesTree, ThinTree) {
|
||||
const TestFixture self;
|
||||
const auto& keys = self.keys;
|
||||
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bayesNet);
|
||||
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||
GTSAM_PRINT(self.bayesNet);
|
||||
self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||
}
|
||||
|
||||
// create a BayesTree out of a Bayes net
|
||||
auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
|
||||
if (debug) {
|
||||
GTSAM_PRINT(*bayesTree);
|
||||
bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
||||
GTSAM_PRINT(*self.bayesTree);
|
||||
self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
||||
}
|
||||
|
||||
// Check frontals and parents
|
||||
for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
|
||||
auto clique_i = (*bayesTree)[i];
|
||||
auto clique_i = (*self.bayesTree)[i];
|
||||
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
||||
}
|
||||
|
||||
auto R = bayesTree->roots().front();
|
||||
auto R = self.bayesTree->roots().front();
|
||||
|
||||
// Check whether BN and BT give the same answer on all configurations
|
||||
vector<DiscreteValues> allPosbValues = cartesianProduct(
|
||||
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
|
||||
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
||||
vector<DiscreteValues> allPosbValues =
|
||||
cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] &
|
||||
keys[5] & keys[6] & keys[7] & keys[8] & keys[9] &
|
||||
keys[10] & keys[11] & keys[12] & keys[13] & keys[14]);
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteValues x = allPosbValues[i];
|
||||
double expected = bayesNet.evaluate(x);
|
||||
double actual = bayesTree->evaluate(x);
|
||||
double expected = self.bayesNet.evaluate(x);
|
||||
double actual = self.bayesTree->evaluate(x);
|
||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||
}
|
||||
|
||||
|
@ -107,7 +120,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteValues x = allPosbValues[i];
|
||||
double px = bayesTree->evaluate(x);
|
||||
double px = self.bayesTree->evaluate(x);
|
||||
for (size_t i = 0; i < 15; i++)
|
||||
if (x[i]) marginals[i] += px;
|
||||
if (x[12] && x[14]) {
|
||||
|
@ -141,46 +154,46 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
DiscreteValues all1 = allPosbValues.back();
|
||||
|
||||
// check separator marginal P(S0)
|
||||
auto clique = (*bayesTree)[0];
|
||||
auto clique = (*self.bayesTree)[0];
|
||||
DiscreteFactorGraph separatorMarginal0 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||
|
||||
// check separator marginal P(S9), should be P(14)
|
||||
clique = (*bayesTree)[9];
|
||||
clique = (*self.bayesTree)[9];
|
||||
DiscreteFactorGraph separatorMarginal9 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||
|
||||
// check separator marginal of root, should be empty
|
||||
clique = (*bayesTree)[11];
|
||||
clique = (*self.bayesTree)[11];
|
||||
DiscreteFactorGraph separatorMarginal11 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||
|
||||
// check shortcut P(S9||R) to root
|
||||
clique = (*bayesTree)[9];
|
||||
clique = (*self.bayesTree)[9];
|
||||
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
LONGS_EQUAL(1, shortcut.size());
|
||||
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// check shortcut P(S8||R) to root
|
||||
clique = (*bayesTree)[8];
|
||||
clique = (*self.bayesTree)[8];
|
||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// check shortcut P(S2||R) to root
|
||||
clique = (*bayesTree)[2];
|
||||
clique = (*self.bayesTree)[2];
|
||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// check shortcut P(S0||R) to root
|
||||
clique = (*bayesTree)[0];
|
||||
clique = (*self.bayesTree)[0];
|
||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// calculate all shortcuts to root
|
||||
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
|
||||
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
|
||||
for (auto clique : cliques) {
|
||||
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
||||
if (debug) {
|
||||
|
@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
// Check all marginals
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
for (size_t i = 0; i < 15; i++) {
|
||||
marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||
double actual = (*marginalFactor)(all1);
|
||||
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||
}
|
||||
|
@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
DiscreteBayesNet::shared_ptr actualJoint;
|
||||
|
||||
// Check joint P(8, 2)
|
||||
actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(1, 2)
|
||||
actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(2, 4)
|
||||
actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(4, 5)
|
||||
actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(4, 6)
|
||||
actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(4, 11)
|
||||
actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesTree, Dot) {
|
||||
const TestFixture self;
|
||||
string actual = self.bayesTree->dot();
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
"0[label=\"13,11,6,7\"];\n"
|
||||
"0->1\n"
|
||||
"1[label=\"14 : 11,13\"];\n"
|
||||
"1->2\n"
|
||||
"2[label=\"9,12 : 14\"];\n"
|
||||
"2->3\n"
|
||||
"3[label=\"3 : 9,12\"];\n"
|
||||
"2->4\n"
|
||||
"4[label=\"2 : 9,12\"];\n"
|
||||
"2->5\n"
|
||||
"5[label=\"8 : 12,14\"];\n"
|
||||
"5->6\n"
|
||||
"6[label=\"1 : 8,12\"];\n"
|
||||
"5->7\n"
|
||||
"7[label=\"0 : 8,12\"];\n"
|
||||
"1->8\n"
|
||||
"8[label=\"10 : 13,14\"];\n"
|
||||
"8->9\n"
|
||||
"9[label=\"5 : 10,13\"];\n"
|
||||
"8->10\n"
|
||||
"10[label=\"4 : 10,13\"];\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -68,8 +68,7 @@ namespace gtsam {
|
|||
/// @{
|
||||
|
||||
/// Output to graphviz format, stream version.
|
||||
virtual void dot(std::ostream& os, const KeyFormatter& keyFormatter =
|
||||
DefaultKeyFormatter) const;
|
||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// Output to graphviz format string.
|
||||
std::string dot(
|
||||
|
|
|
@ -63,20 +63,40 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class CLIQUE>
|
||||
void BayesTree<CLIQUE>::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const {
|
||||
if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!");
|
||||
std::ofstream of(s.c_str());
|
||||
of<< "digraph G{\n";
|
||||
for(const sharedClique& root: roots_)
|
||||
saveGraph(of, root, keyFormatter);
|
||||
of<<"}";
|
||||
template <class CLIQUE>
|
||||
void BayesTree<CLIQUE>::dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
if (roots_.empty())
|
||||
throw std::invalid_argument(
|
||||
"the root of Bayes tree has not been initialized!");
|
||||
os << "digraph G{\n";
|
||||
for (const sharedClique& root : roots_) dot(os, root, keyFormatter);
|
||||
os << "}";
|
||||
std::flush(os);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CLIQUE>
|
||||
std::string BayesTree<CLIQUE>::dot(const KeyFormatter& keyFormatter) const {
|
||||
std::stringstream ss;
|
||||
dot(ss, keyFormatter);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CLIQUE>
|
||||
void BayesTree<CLIQUE>::saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::ofstream of(filename.c_str());
|
||||
dot(of, keyFormatter);
|
||||
of.close();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class CLIQUE>
|
||||
void BayesTree<CLIQUE>::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const {
|
||||
template <class CLIQUE>
|
||||
void BayesTree<CLIQUE>::dot(std::ostream& s, sharedClique clique,
|
||||
const KeyFormatter& indexFormatter,
|
||||
int parentnum) const {
|
||||
static int num = 0;
|
||||
bool first = true;
|
||||
std::stringstream out;
|
||||
|
@ -107,7 +127,7 @@ namespace gtsam {
|
|||
|
||||
for (sharedClique c : clique->children) {
|
||||
num++;
|
||||
saveGraph(s, c, indexFormatter, parentnum);
|
||||
dot(s, c, indexFormatter, parentnum);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -182,12 +182,16 @@ namespace gtsam {
|
|||
*/
|
||||
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;
|
||||
|
||||
/**
|
||||
* Read only with side effects
|
||||
*/
|
||||
/// Output to graphviz format, stream version.
|
||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/** saves the Tree to a text file in GraphViz format */
|
||||
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
/// Output to graphviz format string.
|
||||
std::string dot(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// output to file with graphviz format.
|
||||
void saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
|
@ -236,8 +240,8 @@ namespace gtsam {
|
|||
protected:
|
||||
|
||||
/** private helper method for saving the Tree to a text file in GraphViz format */
|
||||
void saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter,
|
||||
int parentnum = 0) const;
|
||||
void dot(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter,
|
||||
int parentnum = 0) const;
|
||||
|
||||
/** Gather data on a single clique */
|
||||
void getCliqueData(sharedClique clique, BayesTreeCliqueData* stats) const;
|
||||
|
@ -249,7 +253,7 @@ namespace gtsam {
|
|||
void fillNodesIndex(const sharedClique& subtree);
|
||||
|
||||
// Friend JunctionTree because it directly fills roots and nodes index.
|
||||
template<class BAYESRTEE, class GRAPH> friend class EliminatableClusterTree;
|
||||
template<class BAYESTREE, class GRAPH> friend class EliminatableClusterTree;
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
|
|
Loading…
Reference in New Issue