Variable positions for Bayes nets

release/4.3a0
Frank Dellaert 2022-01-26 18:46:06 -05:00
parent 62b188473b
commit ebe3aadada
10 changed files with 157 additions and 96 deletions

View File

@ -18,8 +18,8 @@
#pragma once
#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <boost/range/adaptor/reversed.hpp>
#include <fstream>
@ -29,23 +29,31 @@ namespace gtsam {
/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::print(
const std::string& s, const KeyFormatter& formatter) const {
void BayesNet<CONDITIONAL>::print(const std::string& s,
const KeyFormatter& formatter) const {
Base::print(s, formatter);
}
/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const {
os << "digraph G{\n";
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
writer.digraphPreamble(&os);
// Create nodes for each variable in the graph
for (Key key : this->keys()) {
auto position = writer.variablePos(key);
writer.DrawVariable(key, keyFormatter, position, &os);
}
os << "\n";
for (auto conditional : boost::adaptors::reverse(*this)) {
auto frontals = conditional->frontals();
const Key me = frontals.front();
auto parents = conditional->parents();
for (const Key& p : parents)
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
}
os << "}";
@ -54,18 +62,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,
/* ************************************************************************* */
template <class CONDITIONAL>
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::stringstream ss;
dot(ss, keyFormatter);
dot(ss, keyFormatter, writer);
return ss.str();
}
/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const {
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::ofstream of(filename.c_str());
dot(of, keyFormatter);
dot(of, keyFormatter, writer);
of.close();
}

View File

@ -21,7 +21,6 @@
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
namespace gtsam {
@ -32,24 +31,24 @@ namespace gtsam {
*/
template <class CONDITIONAL>
class BayesNet : public FactorGraph<CONDITIONAL> {
private:
typedef FactorGraph<CONDITIONAL> Base;
public:
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional
typedef typename boost::shared_ptr<CONDITIONAL>
sharedConditional; ///< A shared pointer to a conditional
protected:
/// @name Standard Constructors
/// @{
/** Default constructor as an empty BayesNet */
BayesNet() {};
BayesNet() {}
/** Construct from iterator over conditionals */
template <typename ITERATOR>
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/// @}
@ -68,19 +67,22 @@ namespace gtsam {
/// @{
/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// @}
};
}
} // namespace gtsam
#include <gtsam/inference/BayesNet-inst.h>

View File

@ -25,12 +25,18 @@ using namespace std;
namespace gtsam {
void DotWriter::writePreamble(ostream* os) const {
void DotWriter::graphPreamble(ostream* os) const {
*os << "graph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n";
}
void DotWriter::digraphPreamble(ostream* os) const {
*os << "digraph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n";
}
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position,
ostream* os) {

View File

@ -23,10 +23,14 @@
#include <gtsam/inference/Key.h>
#include <iosfwd>
#include <map>
namespace gtsam {
/// Graphviz formatter.
/**
* @brief DotWriter is a helper class for writing graphviz .dot files.
* @addtogroup inference
*/
struct GTSAM_EXPORT DotWriter {
double figureWidthInches; ///< The figure width on paper in inches
double figureHeightInches; ///< The figure height on paper in inches
@ -35,6 +39,9 @@ struct GTSAM_EXPORT DotWriter {
///< the dot of the factor
bool binaryEdges; ///< just use non-dotted edges for binary factors
/// (optional for each variable) Manually specify variable node positions
std::map<gtsam::Key, Vector2> variablePositions;
explicit DotWriter(double figureWidthInches = 5,
double figureHeightInches = 5,
bool plotFactorPoints = true,
@ -45,8 +52,11 @@ struct GTSAM_EXPORT DotWriter {
connectKeysToFactor(connectKeysToFactor),
binaryEdges(binaryEdges) {}
/// Write out preamble, including size.
void writePreamble(std::ostream* os) const;
/// Write out preamble for graph, including size.
void graphPreamble(std::ostream* os) const;
/// Write out preamble for digraph, including size.
void digraphPreamble(std::ostream* os) const;
/// Create a variable dot fragment.
static void DrawVariable(Key key, const KeyFormatter& keyFormatter,
@ -57,6 +67,15 @@ struct GTSAM_EXPORT DotWriter {
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
std::ostream* os);
/// Return variable position or none
boost::optional<Vector2> variablePos(Key key) const {
auto it = variablePositions.find(key);
if (it == variablePositions.end())
return boost::none;
else
return it->second;
}
/// Draw a single factor, specified by its index i and its variable keys.
void processFactor(size_t i, const KeyVector& keys,
const KeyFormatter& keyFormatter,

View File

@ -131,7 +131,7 @@ template <class FACTOR>
void FactorGraph<FACTOR>::dot(std::ostream& os,
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
writer.writePreamble(&os);
writer.graphPreamble(&os);
// Create nodes for each variable in the graph
for (Key key : keys()) {

View File

@ -301,5 +301,32 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) {
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
TEST(GaussianBayesNet, Dot) {
GaussianBayesNet fragment;
DotWriter writer;
writer.variablePositions.emplace(_x_, Vector2(10, 20));
writer.variablePositions.emplace(_y_, Vector2(50, 20));
auto position = writer.variablePos(_x_);
CHECK(position);
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
noisyBayesNet.saveGraph("noisyBayesNet.dot", DefaultKeyFormatter, writer);
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var11[label=\"11\", pos=\"10,20!\"];\n"
" var22[label=\"22\", pos=\"50,20!\"];\n"
"\n"
" var22->var11\n"
"}");
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
min.y() = std::numeric_limits<double>::infinity();
for (const Key& key : keys) {
if (values.exists(key)) {
boost::optional<Vector2> xy = operator()(values.at(key));
boost::optional<Vector2> xy = extractPosition(values.at(key));
if (xy) {
if (xy->x() < min.x()) min.x() = xy->x();
if (xy->y() < min.y()) min.y() = xy->y();
@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
return min;
}
boost::optional<Vector2> GraphvizFormatting::operator()(
boost::optional<Vector2> GraphvizFormatting::extractPosition(
const Value& value) const {
Vector3 t;
if (const GenericValue<Pose2>* p =
@ -121,12 +121,11 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
return Vector2(x, y);
}
// Return affinely transformed variable position if it exists.
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
const Vector2& min,
Key key) const {
if (!values.exists(key)) return boost::none;
boost::optional<Vector2> xy = operator()(values.at(key));
boost::optional<Vector2> xy = extractPosition(values.at(key));
if (xy) {
xy->x() = scale * (xy->x() - min.x());
xy->y() = scale * (xy->y() - min.y());
@ -134,7 +133,6 @@ boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
return xy;
}
// Return affinely transformed factor position if it exists.
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
size_t i) const {
if (factorPositions.size() == 0) return boost::none;

View File

@ -56,7 +56,7 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
Vector2 findBounds(const Values& values, const KeySet& keys) const;
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
boost::optional<Vector2> operator()(const Value& value) const;
boost::optional<Vector2> extractPosition(const Value& value) const;
/// Return affinely transformed variable position if it exists.
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,

View File

@ -102,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
const KeyFormatter& keyFormatter,
const GraphvizFormatting& writer) const {
writer.writePreamble(&os);
writer.graphPreamble(&os);
// Find bounds (imperative)
KeySet keys = this->keys();

View File

@ -215,20 +215,19 @@ namespace gtsam {
/// Output to graphviz format, stream version, with Values/extra options.
void dot(std::ostream& os, const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& graphvizFormatting =
GraphvizFormatting()) const;
const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// Output to graphviz format string, with Values/extra options.
std::string dot(const Values& values,
std::string dot(
const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& graphvizFormatting =
GraphvizFormatting()) const;
const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// output to file with graphviz format, with Values/extra options.
void saveGraph(const std::string& filename, const Values& values,
void saveGraph(
const std::string& filename, const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& graphvizFormatting =
GraphvizFormatting()) const;
const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// @}
private:
@ -262,16 +261,16 @@ namespace gtsam {
{return updateCholesky(values, dampen);}
/** @deprecated */
void GTSAM_DEPRECATED saveGraph(
std::ostream& os, const Values& values = Values(),
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
void GTSAM_DEPRECATED
saveGraph(std::ostream& os, const Values& values = Values(),
const GraphvizFormatting& writer = GraphvizFormatting(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
dot(os, values, keyFormatter, graphvizFormatting);
}
/** @deprecated */
void GTSAM_DEPRECATED
saveGraph(const std::string& filename, const Values& values,
const GraphvizFormatting& graphvizFormatting,
const GraphvizFormatting& writer,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
saveGraph(filename, values, keyFormatter, graphvizFormatting);
}