From ea2c13bca3c0c1a713c9b03892a5953efec44816 Mon Sep 17 00:00:00 2001 From: jdurand7 Date: Fri, 14 Sep 2012 22:13:33 +0000 Subject: [PATCH] Added method saveGraph for BayesNet. --- gtsam.h | 1 + gtsam/inference/BayesNet-inl.h | 31 ++++++++++++++++--- gtsam/inference/BayesNet.h | 4 +++ .../inference/tests/testSymbolicBayesNet.cpp | 14 +++++++++ 4 files changed, 45 insertions(+), 5 deletions(-) diff --git a/gtsam.h b/gtsam.h index 250d64f0b..6777106b8 100644 --- a/gtsam.h +++ b/gtsam.h @@ -779,6 +779,7 @@ virtual class BayesNet { // Standard interface size_t size() const; void printStats(string s) const; + void saveGraph(string s) const; CONDITIONAL* front() const; CONDITIONAL* back() const; void push_back(CONDITIONAL* conditional); diff --git a/gtsam/inference/BayesNet-inl.h b/gtsam/inference/BayesNet-inl.h index 912e49ef4..e113b7628 100644 --- a/gtsam/inference/BayesNet-inl.h +++ b/gtsam/inference/BayesNet-inl.h @@ -41,6 +41,15 @@ namespace gtsam { conditional->print("Conditional", formatter); } + /* ************************************************************************* */ + template + bool BayesNet::equals(const BayesNet& cbn, double tol) const { + if (size() != cbn.size()) + return false; + return std::equal(conditionals_.begin(), conditionals_.end(), + cbn.conditionals_.begin(), equals_star(tol)); + } + /* ************************************************************************* */ template void BayesNet::printStats(const std::string& s) const { @@ -60,11 +69,23 @@ namespace gtsam { /* ************************************************************************* */ template - bool BayesNet::equals(const BayesNet& cbn, double tol) const { - if (size() != cbn.size()) - return false; - return std::equal(conditionals_.begin(), conditionals_.end(), - cbn.conditionals_.begin(), equals_star(tol)); + void BayesNet::saveGraph(const std::string &s, + const IndexFormatter& indexFormatter) const { + std::ofstream of(s.c_str()); + of << "digraph G{\n"; + + BOOST_FOREACH(typename CONDITIONAL::shared_ptr conditional, conditionals_) { + typename CONDITIONAL::Frontals frontals = conditional->frontals(); + Index me = frontals.front(); +// of << me << std::endl; + typename CONDITIONAL::Parents parents = conditional->parents(); + BOOST_FOREACH(Index p, parents) + of << p << "->" << me << std::endl; + } + + + of << "}"; + of.close(); } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 4fa10eee2..0bb5efedb 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -108,6 +108,10 @@ public: /** print statistics */ void printStats(const std::string& s = "") const; + /** save dot graph */ + void saveGraph(const std::string &s, const IndexFormatter& indexFormatter = + DefaultIndexFormatter) const; + /** return keys in reverse topological sort order, i.e., elimination order */ FastList ordering() const; diff --git a/gtsam/inference/tests/testSymbolicBayesNet.cpp b/gtsam/inference/tests/testSymbolicBayesNet.cpp index 06b1bded9..445bab697 100644 --- a/gtsam/inference/tests/testSymbolicBayesNet.cpp +++ b/gtsam/inference/tests/testSymbolicBayesNet.cpp @@ -188,6 +188,20 @@ TEST_UNSAFE(SymbolicBayesNet, popLeaf) { //#endif } +/* ************************************************************************* */ +TEST(SymbolicBayesNet, saveGraph) { + SymbolicBayesNet bn; + bn += IndexConditional::shared_ptr(new IndexConditional(_A_, _B_)); + std::vector keys; + keys.push_back(_B_); + keys.push_back(_C_); + keys.push_back(_D_); + bn += IndexConditional::shared_ptr(new IndexConditional(keys,2)); + bn += IndexConditional::shared_ptr(new IndexConditional(_D_)); + + bn.saveGraph("SymbolicBayesNet.dot"); +} + /* ************************************************************************* */ int main() { TestResult tr;