Variable positions for Bayes nets

release/4.3a0
Frank Dellaert 2022-01-26 18:46:06 -05:00
parent 62b188473b
commit ebe3aadada
10 changed files with 157 additions and 96 deletions

View File

@ -10,16 +10,16 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file BayesNet.h * @file BayesNet.h
* @brief Bayes network * @brief Bayes network
* @author Frank Dellaert * @author Frank Dellaert
* @author Richard Roberts * @author Richard Roberts
*/ */
#pragma once #pragma once
#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <boost/range/adaptor/reversed.hpp> #include <boost/range/adaptor/reversed.hpp>
#include <fstream> #include <fstream>
@ -29,23 +29,31 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::print( void BayesNet<CONDITIONAL>::print(const std::string& s,
const std::string& s, const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
Base::print(s, formatter); Base::print(s, formatter);
} }
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::dot(std::ostream& os, void BayesNet<CONDITIONAL>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter,
os << "digraph G{\n"; const DotWriter& writer) const {
writer.digraphPreamble(&os);
// Create nodes for each variable in the graph
for (Key key : this->keys()) {
auto position = writer.variablePos(key);
writer.DrawVariable(key, keyFormatter, position, &os);
}
os << "\n";
for (auto conditional : boost::adaptors::reverse(*this)) { for (auto conditional : boost::adaptors::reverse(*this)) {
auto frontals = conditional->frontals(); auto frontals = conditional->frontals();
const Key me = frontals.front(); const Key me = frontals.front();
auto parents = conditional->parents(); auto parents = conditional->parents();
for (const Key& p : parents) for (const Key& p : parents)
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n"; os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
} }
os << "}"; os << "}";
@ -54,18 +62,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const { std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::stringstream ss; std::stringstream ss;
dot(ss, keyFormatter); dot(ss, keyFormatter, writer);
return ss.str(); return ss.str();
} }
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename, void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::ofstream of(filename.c_str()); std::ofstream of(filename.c_str());
dot(of, keyFormatter); dot(of, keyFormatter, writer);
of.close(); of.close();
} }

View File

@ -10,77 +10,79 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file BayesNet.h * @file BayesNet.h
* @brief Bayes network * @brief Bayes network
* @author Frank Dellaert * @author Frank Dellaert
* @author Richard Roberts * @author Richard Roberts
*/ */
#pragma once #pragma once
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <string> #include <string>
namespace gtsam { namespace gtsam {
/** /**
* A BayesNet is a tree of conditionals, stored in elimination order. * A BayesNet is a tree of conditionals, stored in elimination order.
* @addtogroup inference * @addtogroup inference
*/ */
template<class CONDITIONAL> template <class CONDITIONAL>
class BayesNet : public FactorGraph<CONDITIONAL> { class BayesNet : public FactorGraph<CONDITIONAL> {
private:
typedef FactorGraph<CONDITIONAL> Base;
private: public:
typedef typename boost::shared_ptr<CONDITIONAL>
sharedConditional; ///< A shared pointer to a conditional
typedef FactorGraph<CONDITIONAL> Base; protected:
/// @name Standard Constructors
/// @{
public: /** Default constructor as an empty BayesNet */
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional BayesNet() {}
protected: /** Construct from iterator over conditionals */
/// @name Standard Constructors template <typename ITERATOR>
/// @{ BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Default constructor as an empty BayesNet */ /// @}
BayesNet() {};
/** Construct from iterator over conditionals */ public:
template<typename ITERATOR> /// @name Testable
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} /// @{
/// @} /** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
public: /// @}
/// @name Testable
/// @{
/** print out graph */ /// @name Graph Display
void print( /// @{
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @} /// Output to graphviz format, stream version.
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// @name Graph Display /// Output to graphviz format string.
/// @{ std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// Output to graphviz format, stream version. /// output to file with graphviz format.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// Output to graphviz format string. /// @}
std::string dot( };
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format. } // namespace gtsam
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
};
}
#include <gtsam/inference/BayesNet-inst.h> #include <gtsam/inference/BayesNet-inst.h>

View File

@ -25,12 +25,18 @@ using namespace std;
namespace gtsam { namespace gtsam {
void DotWriter::writePreamble(ostream* os) const { void DotWriter::graphPreamble(ostream* os) const {
*os << "graph {\n"; *os << "graph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches *os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n"; << "\";\n\n";
} }
void DotWriter::digraphPreamble(ostream* os) const {
*os << "digraph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n";
}
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter, void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position, const boost::optional<Vector2>& position,
ostream* os) { ostream* os) {

View File

@ -23,10 +23,14 @@
#include <gtsam/inference/Key.h> #include <gtsam/inference/Key.h>
#include <iosfwd> #include <iosfwd>
#include <map>
namespace gtsam { namespace gtsam {
/// Graphviz formatter. /**
* @brief DotWriter is a helper class for writing graphviz .dot files.
* @addtogroup inference
*/
struct GTSAM_EXPORT DotWriter { struct GTSAM_EXPORT DotWriter {
double figureWidthInches; ///< The figure width on paper in inches double figureWidthInches; ///< The figure width on paper in inches
double figureHeightInches; ///< The figure height on paper in inches double figureHeightInches; ///< The figure height on paper in inches
@ -35,6 +39,9 @@ struct GTSAM_EXPORT DotWriter {
///< the dot of the factor ///< the dot of the factor
bool binaryEdges; ///< just use non-dotted edges for binary factors bool binaryEdges; ///< just use non-dotted edges for binary factors
/// (optional for each variable) Manually specify variable node positions
std::map<gtsam::Key, Vector2> variablePositions;
explicit DotWriter(double figureWidthInches = 5, explicit DotWriter(double figureWidthInches = 5,
double figureHeightInches = 5, double figureHeightInches = 5,
bool plotFactorPoints = true, bool plotFactorPoints = true,
@ -45,8 +52,11 @@ struct GTSAM_EXPORT DotWriter {
connectKeysToFactor(connectKeysToFactor), connectKeysToFactor(connectKeysToFactor),
binaryEdges(binaryEdges) {} binaryEdges(binaryEdges) {}
/// Write out preamble, including size. /// Write out preamble for graph, including size.
void writePreamble(std::ostream* os) const; void graphPreamble(std::ostream* os) const;
/// Write out preamble for digraph, including size.
void digraphPreamble(std::ostream* os) const;
/// Create a variable dot fragment. /// Create a variable dot fragment.
static void DrawVariable(Key key, const KeyFormatter& keyFormatter, static void DrawVariable(Key key, const KeyFormatter& keyFormatter,
@ -57,6 +67,15 @@ struct GTSAM_EXPORT DotWriter {
static void DrawFactor(size_t i, const boost::optional<Vector2>& position, static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
std::ostream* os); std::ostream* os);
/// Return variable position or none
boost::optional<Vector2> variablePos(Key key) const {
auto it = variablePositions.find(key);
if (it == variablePositions.end())
return boost::none;
else
return it->second;
}
/// Draw a single factor, specified by its index i and its variable keys. /// Draw a single factor, specified by its index i and its variable keys.
void processFactor(size_t i, const KeyVector& keys, void processFactor(size_t i, const KeyVector& keys,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,

View File

@ -131,7 +131,7 @@ template <class FACTOR>
void FactorGraph<FACTOR>::dot(std::ostream& os, void FactorGraph<FACTOR>::dot(std::ostream& os,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
const DotWriter& writer) const { const DotWriter& writer) const {
writer.writePreamble(&os); writer.graphPreamble(&os);
// Create nodes for each variable in the graph // Create nodes for each variable in the graph
for (Key key : keys()) { for (Key key : keys()) {

View File

@ -301,5 +301,32 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr);} TEST(GaussianBayesNet, Dot) {
GaussianBayesNet fragment;
DotWriter writer;
writer.variablePositions.emplace(_x_, Vector2(10, 20));
writer.variablePositions.emplace(_y_, Vector2(50, 20));
auto position = writer.variablePos(_x_);
CHECK(position);
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
noisyBayesNet.saveGraph("noisyBayesNet.dot", DefaultKeyFormatter, writer);
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var11[label=\"11\", pos=\"10,20!\"];\n"
" var22[label=\"22\", pos=\"50,20!\"];\n"
"\n"
" var22->var11\n"
"}");
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
min.y() = std::numeric_limits<double>::infinity(); min.y() = std::numeric_limits<double>::infinity();
for (const Key& key : keys) { for (const Key& key : keys) {
if (values.exists(key)) { if (values.exists(key)) {
boost::optional<Vector2> xy = operator()(values.at(key)); boost::optional<Vector2> xy = extractPosition(values.at(key));
if (xy) { if (xy) {
if (xy->x() < min.x()) min.x() = xy->x(); if (xy->x() < min.x()) min.x() = xy->x();
if (xy->y() < min.y()) min.y() = xy->y(); if (xy->y() < min.y()) min.y() = xy->y();
@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
return min; return min;
} }
boost::optional<Vector2> GraphvizFormatting::operator()( boost::optional<Vector2> GraphvizFormatting::extractPosition(
const Value& value) const { const Value& value) const {
Vector3 t; Vector3 t;
if (const GenericValue<Pose2>* p = if (const GenericValue<Pose2>* p =
@ -121,12 +121,11 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
return Vector2(x, y); return Vector2(x, y);
} }
// Return affinely transformed variable position if it exists.
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values, boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
const Vector2& min, const Vector2& min,
Key key) const { Key key) const {
if (!values.exists(key)) return boost::none; if (!values.exists(key)) return boost::none;
boost::optional<Vector2> xy = operator()(values.at(key)); boost::optional<Vector2> xy = extractPosition(values.at(key));
if (xy) { if (xy) {
xy->x() = scale * (xy->x() - min.x()); xy->x() = scale * (xy->x() - min.x());
xy->y() = scale * (xy->y() - min.y()); xy->y() = scale * (xy->y() - min.y());
@ -134,7 +133,6 @@ boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
return xy; return xy;
} }
// Return affinely transformed factor position if it exists.
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min, boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
size_t i) const { size_t i) const {
if (factorPositions.size() == 0) return boost::none; if (factorPositions.size() == 0) return boost::none;

View File

@ -33,10 +33,10 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
/// World axes to be assigned to paper axes /// World axes to be assigned to paper axes
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
///< paper axis ///< paper axis
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
///< axis ///< axis
double scale; ///< Scale all positions to reduce / increase density double scale; ///< Scale all positions to reduce / increase density
bool mergeSimilarFactors; ///< Merge multiple factors that have the same bool mergeSimilarFactors; ///< Merge multiple factors that have the same
///< connectivity ///< connectivity
@ -55,8 +55,8 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
// Find bounds // Find bounds
Vector2 findBounds(const Values& values, const KeySet& keys) const; Vector2 findBounds(const Values& values, const KeySet& keys) const;
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
boost::optional<Vector2> operator()(const Value& value) const; boost::optional<Vector2> extractPosition(const Value& value) const;
/// Return affinely transformed variable position if it exists. /// Return affinely transformed variable position if it exists.
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min, boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,

View File

@ -102,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
const GraphvizFormatting& writer) const { const GraphvizFormatting& writer) const {
writer.writePreamble(&os); writer.graphPreamble(&os);
// Find bounds (imperative) // Find bounds (imperative)
KeySet keys = this->keys(); KeySet keys = this->keys();

View File

@ -215,20 +215,19 @@ namespace gtsam {
/// Output to graphviz format, stream version, with Values/extra options. /// Output to graphviz format, stream version, with Values/extra options.
void dot(std::ostream& os, const Values& values, void dot(std::ostream& os, const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& graphvizFormatting = const GraphvizFormatting& writer = GraphvizFormatting()) const;
GraphvizFormatting()) const;
/// Output to graphviz format string, with Values/extra options. /// Output to graphviz format string, with Values/extra options.
std::string dot(const Values& values, std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const Values& values,
const GraphvizFormatting& graphvizFormatting = const KeyFormatter& keyFormatter = DefaultKeyFormatter,
GraphvizFormatting()) const; const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// output to file with graphviz format, with Values/extra options. /// output to file with graphviz format, with Values/extra options.
void saveGraph(const std::string& filename, const Values& values, void saveGraph(
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const std::string& filename, const Values& values,
const GraphvizFormatting& graphvizFormatting = const KeyFormatter& keyFormatter = DefaultKeyFormatter,
GraphvizFormatting()) const; const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// @} /// @}
private: private:
@ -262,16 +261,16 @@ namespace gtsam {
{return updateCholesky(values, dampen);} {return updateCholesky(values, dampen);}
/** @deprecated */ /** @deprecated */
void GTSAM_DEPRECATED saveGraph( void GTSAM_DEPRECATED
std::ostream& os, const Values& values = Values(), saveGraph(std::ostream& os, const Values& values = Values(),
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), const GraphvizFormatting& writer = GraphvizFormatting(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
dot(os, values, keyFormatter, graphvizFormatting); dot(os, values, keyFormatter, graphvizFormatting);
} }
/** @deprecated */ /** @deprecated */
void GTSAM_DEPRECATED void GTSAM_DEPRECATED
saveGraph(const std::string& filename, const Values& values, saveGraph(const std::string& filename, const Values& values,
const GraphvizFormatting& graphvizFormatting, const GraphvizFormatting& writer,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
saveGraph(filename, values, keyFormatter, graphvizFormatting); saveGraph(filename, values, keyFormatter, graphvizFormatting);
} }