diff --git a/gtsam/inference/BayesNet-inst.h b/gtsam/inference/BayesNet-inst.h index a73762258..4674b0083 100644 --- a/gtsam/inference/BayesNet-inst.h +++ b/gtsam/inference/BayesNet-inst.h @@ -35,21 +35,39 @@ void BayesNet::print( /* ************************************************************************* */ template -void BayesNet::saveGraph(const std::string& s, - const KeyFormatter& keyFormatter) const { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; +void BayesNet::dot(std::ostream& os, + const KeyFormatter& keyFormatter) const { + os << "digraph G{\n"; for (auto conditional : boost::adaptors::reverse(*this)) { typename CONDITIONAL::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - typename CONDITIONAL::Parents parents = conditional->parents(); - for (Key p : parents) - of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl; + const Key me = frontals.front(); + auto parents = conditional->parents(); + for (const Key& p : parents) + os << keyFormatter(p) << "->" << keyFormatter(me) << "\n"; } - of << "}"; + os << "}"; + std::flush(os); +} + +/* ************************************************************************* */ +template +std::string BayesNet::dot(const KeyFormatter& keyFormatter) const { + std::stringstream ss; + dot(ss, keyFormatter); + return ss.str(); +} + +/* ************************************************************************* */ +template +void BayesNet::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter); of.close(); } +/* ************************************************************************* */ + } // namespace gtsam diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 938278d5a..e430b3fe4 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -67,8 +67,19 @@ namespace gtsam { /// @name Standard Interface /// @{ - void saveGraph(const std::string& s, + /// Output to graphviz format, stream version. + virtual void dot(std::ostream& os, const KeyFormatter& keyFormatter = + DefaultKeyFormatter) const; + + /// Output to graphviz format string. + std::string dot( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// output to file with graphviz format. + void saveGraph(const std::string& filename, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} }; }