diff --git a/gtsam/inference/BayesNet-inst.h b/gtsam/inference/BayesNet-inst.h index 0b1c69d50..bd90f4e4b 100644 --- a/gtsam/inference/BayesNet-inst.h +++ b/gtsam/inference/BayesNet-inst.h @@ -10,16 +10,16 @@ * -------------------------------------------------------------------------- */ /** -* @file BayesNet.h -* @brief Bayes network -* @author Frank Dellaert -* @author Richard Roberts -*/ + * @file BayesNet.h + * @brief Bayes network + * @author Frank Dellaert + * @author Richard Roberts + */ #pragma once -#include #include +#include #include #include @@ -29,23 +29,31 @@ namespace gtsam { /* ************************************************************************* */ template -void BayesNet::print( - const std::string& s, const KeyFormatter& formatter) const { +void BayesNet::print(const std::string& s, + const KeyFormatter& formatter) const { Base::print(s, formatter); } /* ************************************************************************* */ template void BayesNet::dot(std::ostream& os, - const KeyFormatter& keyFormatter) const { - os << "digraph G{\n"; + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { + writer.digraphPreamble(&os); + + // Create nodes for each variable in the graph + for (Key key : this->keys()) { + auto position = writer.variablePos(key); + writer.DrawVariable(key, keyFormatter, position, &os); + } + os << "\n"; for (auto conditional : boost::adaptors::reverse(*this)) { auto frontals = conditional->frontals(); const Key me = frontals.front(); auto parents = conditional->parents(); for (const Key& p : parents) - os << keyFormatter(p) << "->" << keyFormatter(me) << "\n"; + os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n"; } os << "}"; @@ -54,18 +62,20 @@ void BayesNet::dot(std::ostream& os, /* ************************************************************************* */ template -std::string BayesNet::dot(const KeyFormatter& keyFormatter) const { +std::string BayesNet::dot(const KeyFormatter& keyFormatter, + const DotWriter& writer) const { std::stringstream ss; - dot(ss, keyFormatter); + dot(ss, keyFormatter, writer); return ss.str(); } /* ************************************************************************* */ template void BayesNet::saveGraph(const std::string& filename, - const KeyFormatter& keyFormatter) const { + const KeyFormatter& keyFormatter, + const DotWriter& writer) const { std::ofstream of(filename.c_str()); - dot(of, keyFormatter); + dot(of, keyFormatter, writer); of.close(); } diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 6dfe60dfe..219864c54 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -10,77 +10,79 @@ * -------------------------------------------------------------------------- */ /** -* @file BayesNet.h -* @brief Bayes network -* @author Frank Dellaert -* @author Richard Roberts -*/ + * @file BayesNet.h + * @brief Bayes network + * @author Frank Dellaert + * @author Richard Roberts + */ #pragma once #include #include - #include namespace gtsam { - /** - * A BayesNet is a tree of conditionals, stored in elimination order. - * @addtogroup inference - */ - template - class BayesNet : public FactorGraph { +/** + * A BayesNet is a tree of conditionals, stored in elimination order. + * @addtogroup inference + */ +template +class BayesNet : public FactorGraph { + private: + typedef FactorGraph Base; - private: + public: + typedef typename boost::shared_ptr + sharedConditional; ///< A shared pointer to a conditional - typedef FactorGraph Base; + protected: + /// @name Standard Constructors + /// @{ - public: - typedef typename boost::shared_ptr sharedConditional; ///< A shared pointer to a conditional + /** Default constructor as an empty BayesNet */ + BayesNet() {} - protected: - /// @name Standard Constructors - /// @{ + /** Construct from iterator over conditionals */ + template + BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) + : Base(firstConditional, lastConditional) {} - /** Default constructor as an empty BayesNet */ - BayesNet() {}; + /// @} - /** Construct from iterator over conditionals */ - template - BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} + public: + /// @name Testable + /// @{ - /// @} + /** print out graph */ + void print( + const std::string& s = "BayesNet", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; - public: - /// @name Testable - /// @{ + /// @} - /** print out graph */ - void print( - const std::string& s = "BayesNet", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; + /// @name Graph Display + /// @{ - /// @} + /// Output to graphviz format, stream version. + void dot(std::ostream& os, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; - /// @name Graph Display - /// @{ + /// Output to graphviz format string. + std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; - /// Output to graphviz format, stream version. - void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DotWriter& writer = DotWriter()) const; - /// Output to graphviz format string. - std::string dot( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// @} +}; - /// output to file with graphviz format. - void saveGraph(const std::string& filename, - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - - /// @} - }; - -} +} // namespace gtsam #include diff --git a/gtsam/inference/DotWriter.cpp b/gtsam/inference/DotWriter.cpp index 18130c35d..a45482efb 100644 --- a/gtsam/inference/DotWriter.cpp +++ b/gtsam/inference/DotWriter.cpp @@ -25,12 +25,18 @@ using namespace std; namespace gtsam { -void DotWriter::writePreamble(ostream* os) const { +void DotWriter::graphPreamble(ostream* os) const { *os << "graph {\n"; *os << " size=\"" << figureWidthInches << "," << figureHeightInches << "\";\n\n"; } +void DotWriter::digraphPreamble(ostream* os) const { + *os << "digraph {\n"; + *os << " size=\"" << figureWidthInches << "," << figureHeightInches + << "\";\n\n"; +} + void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter, const boost::optional& position, ostream* os) { diff --git a/gtsam/inference/DotWriter.h b/gtsam/inference/DotWriter.h index 93c229c2b..ad420b181 100644 --- a/gtsam/inference/DotWriter.h +++ b/gtsam/inference/DotWriter.h @@ -23,10 +23,14 @@ #include #include +#include namespace gtsam { -/// Graphviz formatter. +/** + * @brief DotWriter is a helper class for writing graphviz .dot files. + * @addtogroup inference + */ struct GTSAM_EXPORT DotWriter { double figureWidthInches; ///< The figure width on paper in inches double figureHeightInches; ///< The figure height on paper in inches @@ -35,6 +39,9 @@ struct GTSAM_EXPORT DotWriter { ///< the dot of the factor bool binaryEdges; ///< just use non-dotted edges for binary factors + /// (optional for each variable) Manually specify variable node positions + std::map variablePositions; + explicit DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, bool plotFactorPoints = true, @@ -45,8 +52,11 @@ struct GTSAM_EXPORT DotWriter { connectKeysToFactor(connectKeysToFactor), binaryEdges(binaryEdges) {} - /// Write out preamble, including size. - void writePreamble(std::ostream* os) const; + /// Write out preamble for graph, including size. + void graphPreamble(std::ostream* os) const; + + /// Write out preamble for digraph, including size. + void digraphPreamble(std::ostream* os) const; /// Create a variable dot fragment. static void DrawVariable(Key key, const KeyFormatter& keyFormatter, @@ -57,6 +67,15 @@ struct GTSAM_EXPORT DotWriter { static void DrawFactor(size_t i, const boost::optional& position, std::ostream* os); + /// Return variable position or none + boost::optional variablePos(Key key) const { + auto it = variablePositions.find(key); + if (it == variablePositions.end()) + return boost::none; + else + return it->second; + } + /// Draw a single factor, specified by its index i and its variable keys. void processFactor(size_t i, const KeyVector& keys, const KeyFormatter& keyFormatter, diff --git a/gtsam/inference/FactorGraph-inst.h b/gtsam/inference/FactorGraph-inst.h index 3ea17fc7f..2034fdcb6 100644 --- a/gtsam/inference/FactorGraph-inst.h +++ b/gtsam/inference/FactorGraph-inst.h @@ -131,7 +131,7 @@ template void FactorGraph::dot(std::ostream& os, const KeyFormatter& keyFormatter, const DotWriter& writer) const { - writer.writePreamble(&os); + writer.graphPreamble(&os); // Create nodes for each variable in the graph for (Key key : keys()) { diff --git a/gtsam/linear/tests/testGaussianBayesNet.cpp b/gtsam/linear/tests/testGaussianBayesNet.cpp index 00a338e54..11fc7e7f7 100644 --- a/gtsam/linear/tests/testGaussianBayesNet.cpp +++ b/gtsam/linear/tests/testGaussianBayesNet.cpp @@ -301,5 +301,32 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) { } /* ************************************************************************* */ -int main() { TestResult tr; return TestRegistry::runAllTests(tr);} +TEST(GaussianBayesNet, Dot) { + GaussianBayesNet fragment; + DotWriter writer; + writer.variablePositions.emplace(_x_, Vector2(10, 20)); + writer.variablePositions.emplace(_y_, Vector2(50, 20)); + + auto position = writer.variablePos(_x_); + CHECK(position); + EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5)); + + string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer); + noisyBayesNet.saveGraph("noisyBayesNet.dot", DefaultKeyFormatter, writer); + EXPECT(actual == + "digraph {\n" + " size=\"5,5\";\n" + "\n" + " var11[label=\"11\", pos=\"10,20!\"];\n" + " var22[label=\"22\", pos=\"50,20!\"];\n" + "\n" + " var22->var11\n" + "}"); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} /* ************************************************************************* */ diff --git a/gtsam/nonlinear/GraphvizFormatting.cpp b/gtsam/nonlinear/GraphvizFormatting.cpp index e5b81c66b..1f0b3a875 100644 --- a/gtsam/nonlinear/GraphvizFormatting.cpp +++ b/gtsam/nonlinear/GraphvizFormatting.cpp @@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values, min.y() = std::numeric_limits::infinity(); for (const Key& key : keys) { if (values.exists(key)) { - boost::optional xy = operator()(values.at(key)); + boost::optional xy = extractPosition(values.at(key)); if (xy) { if (xy->x() < min.x()) min.x() = xy->x(); if (xy->y() < min.y()) min.y() = xy->y(); @@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values, return min; } -boost::optional GraphvizFormatting::operator()( +boost::optional GraphvizFormatting::extractPosition( const Value& value) const { Vector3 t; if (const GenericValue* p = @@ -121,12 +121,11 @@ boost::optional GraphvizFormatting::operator()( return Vector2(x, y); } -// Return affinely transformed variable position if it exists. boost::optional GraphvizFormatting::variablePos(const Values& values, const Vector2& min, Key key) const { if (!values.exists(key)) return boost::none; - boost::optional xy = operator()(values.at(key)); + boost::optional xy = extractPosition(values.at(key)); if (xy) { xy->x() = scale * (xy->x() - min.x()); xy->y() = scale * (xy->y() - min.y()); @@ -134,7 +133,6 @@ boost::optional GraphvizFormatting::variablePos(const Values& values, return xy; } -// Return affinely transformed factor position if it exists. boost::optional GraphvizFormatting::factorPos(const Vector2& min, size_t i) const { if (factorPositions.size() == 0) return boost::none; diff --git a/gtsam/nonlinear/GraphvizFormatting.h b/gtsam/nonlinear/GraphvizFormatting.h index c36b09a8f..d71e73f31 100644 --- a/gtsam/nonlinear/GraphvizFormatting.h +++ b/gtsam/nonlinear/GraphvizFormatting.h @@ -33,10 +33,10 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter { /// World axes to be assigned to paper axes enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; - Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal - ///< paper axis - Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper - ///< axis + Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal + ///< paper axis + Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper + ///< axis double scale; ///< Scale all positions to reduce / increase density bool mergeSimilarFactors; ///< Merge multiple factors that have the same ///< connectivity @@ -55,8 +55,8 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter { // Find bounds Vector2 findBounds(const Values& values, const KeySet& keys) const; - /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 - boost::optional operator()(const Value& value) const; + /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 + boost::optional extractPosition(const Value& value) const; /// Return affinely transformed variable position if it exists. boost::optional variablePos(const Values& values, const Vector2& min, diff --git a/gtsam/nonlinear/NonlinearFactorGraph.cpp b/gtsam/nonlinear/NonlinearFactorGraph.cpp index da8935d5f..c03caed75 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.cpp +++ b/gtsam/nonlinear/NonlinearFactorGraph.cpp @@ -102,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol) void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, const KeyFormatter& keyFormatter, const GraphvizFormatting& writer) const { - writer.writePreamble(&os); + writer.graphPreamble(&os); // Find bounds (imperative) KeySet keys = this->keys(); diff --git a/gtsam/nonlinear/NonlinearFactorGraph.h b/gtsam/nonlinear/NonlinearFactorGraph.h index ea8748f63..6f083a323 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.h +++ b/gtsam/nonlinear/NonlinearFactorGraph.h @@ -215,20 +215,19 @@ namespace gtsam { /// Output to graphviz format, stream version, with Values/extra options. void dot(std::ostream& os, const Values& values, const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const GraphvizFormatting& graphvizFormatting = - GraphvizFormatting()) const; + const GraphvizFormatting& writer = GraphvizFormatting()) const; /// Output to graphviz format string, with Values/extra options. - std::string dot(const Values& values, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const GraphvizFormatting& graphvizFormatting = - GraphvizFormatting()) const; + std::string dot( + const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; /// output to file with graphviz format, with Values/extra options. - void saveGraph(const std::string& filename, const Values& values, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const GraphvizFormatting& graphvizFormatting = - GraphvizFormatting()) const; + void saveGraph( + const std::string& filename, const Values& values, + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const GraphvizFormatting& writer = GraphvizFormatting()) const; /// @} private: @@ -262,16 +261,16 @@ namespace gtsam { {return updateCholesky(values, dampen);} /** @deprecated */ - void GTSAM_DEPRECATED saveGraph( - std::ostream& os, const Values& values = Values(), - const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + void GTSAM_DEPRECATED + saveGraph(std::ostream& os, const Values& values = Values(), + const GraphvizFormatting& writer = GraphvizFormatting(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { dot(os, values, keyFormatter, graphvizFormatting); } /** @deprecated */ void GTSAM_DEPRECATED saveGraph(const std::string& filename, const Values& values, - const GraphvizFormatting& graphvizFormatting, + const GraphvizFormatting& writer, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { saveGraph(filename, values, keyFormatter, graphvizFormatting); }