Added method saveGraph for BayesNet.

release/4.3a0
jdurand7 2012-09-14 22:13:33 +00:00
parent 090133f944
commit ea2c13bca3
4 changed files with 45 additions and 5 deletions

View File

@ -779,6 +779,7 @@ virtual class BayesNet {
// Standard interface
size_t size() const;
void printStats(string s) const;
void saveGraph(string s) const;
CONDITIONAL* front() const;
CONDITIONAL* back() const;
void push_back(CONDITIONAL* conditional);

View File

@ -41,6 +41,15 @@ namespace gtsam {
conditional->print("Conditional", formatter);
}
/* ************************************************************************* */
template<class CONDITIONAL>
bool BayesNet<CONDITIONAL>::equals(const BayesNet& cbn, double tol) const {
if (size() != cbn.size())
return false;
return std::equal(conditionals_.begin(), conditionals_.end(),
cbn.conditionals_.begin(), equals_star<CONDITIONAL>(tol));
}
/* ************************************************************************* */
template<class CONDITIONAL>
void BayesNet<CONDITIONAL>::printStats(const std::string& s) const {
@ -60,11 +69,23 @@ namespace gtsam {
/* ************************************************************************* */
template<class CONDITIONAL>
bool BayesNet<CONDITIONAL>::equals(const BayesNet& cbn, double tol) const {
if (size() != cbn.size())
return false;
return std::equal(conditionals_.begin(), conditionals_.end(),
cbn.conditionals_.begin(), equals_star<CONDITIONAL>(tol));
void BayesNet<CONDITIONAL>::saveGraph(const std::string &s,
const IndexFormatter& indexFormatter) const {
std::ofstream of(s.c_str());
of << "digraph G{\n";
BOOST_FOREACH(typename CONDITIONAL::shared_ptr conditional, conditionals_) {
typename CONDITIONAL::Frontals frontals = conditional->frontals();
Index me = frontals.front();
// of << me << std::endl;
typename CONDITIONAL::Parents parents = conditional->parents();
BOOST_FOREACH(Index p, parents)
of << p << "->" << me << std::endl;
}
of << "}";
of.close();
}
/* ************************************************************************* */

View File

@ -108,6 +108,10 @@ public:
/** print statistics */
void printStats(const std::string& s = "") const;
/** save dot graph */
void saveGraph(const std::string &s, const IndexFormatter& indexFormatter =
DefaultIndexFormatter) const;
/** return keys in reverse topological sort order, i.e., elimination order */
FastList<Index> ordering() const;

View File

@ -188,6 +188,20 @@ TEST_UNSAFE(SymbolicBayesNet, popLeaf) {
//#endif
}
/* ************************************************************************* */
TEST(SymbolicBayesNet, saveGraph) {
SymbolicBayesNet bn;
bn += IndexConditional::shared_ptr(new IndexConditional(_A_, _B_));
std::vector<Index> keys;
keys.push_back(_B_);
keys.push_back(_C_);
keys.push_back(_D_);
bn += IndexConditional::shared_ptr(new IndexConditional(keys,2));
bn += IndexConditional::shared_ptr(new IndexConditional(_D_));
bn.saveGraph("SymbolicBayesNet.dot");
}
/* ************************************************************************* */
int main() {
TestResult tr;