Variable positions for Bayes nets
parent
62b188473b
commit
ebe3aadada
|
@ -18,8 +18,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
|
||||
#include <boost/range/adaptor/reversed.hpp>
|
||||
#include <fstream>
|
||||
|
@ -29,23 +29,31 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
void BayesNet<CONDITIONAL>::print(
|
||||
const std::string& s, const KeyFormatter& formatter) const {
|
||||
void BayesNet<CONDITIONAL>::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
os << "digraph G{\n";
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DotWriter& writer) const {
|
||||
writer.digraphPreamble(&os);
|
||||
|
||||
// Create nodes for each variable in the graph
|
||||
for (Key key : this->keys()) {
|
||||
auto position = writer.variablePos(key);
|
||||
writer.DrawVariable(key, keyFormatter, position, &os);
|
||||
}
|
||||
os << "\n";
|
||||
|
||||
for (auto conditional : boost::adaptors::reverse(*this)) {
|
||||
auto frontals = conditional->frontals();
|
||||
const Key me = frontals.front();
|
||||
auto parents = conditional->parents();
|
||||
for (const Key& p : parents)
|
||||
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
|
||||
os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
|
||||
}
|
||||
|
||||
os << "}";
|
||||
|
@ -54,18 +62,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
|||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
|
||||
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
|
||||
const DotWriter& writer) const {
|
||||
std::stringstream ss;
|
||||
dot(ss, keyFormatter);
|
||||
dot(ss, keyFormatter, writer);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DotWriter& writer) const {
|
||||
std::ofstream of(filename.c_str());
|
||||
dot(of, keyFormatter);
|
||||
dot(of, keyFormatter, writer);
|
||||
of.close();
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -32,24 +31,24 @@ namespace gtsam {
|
|||
*/
|
||||
template <class CONDITIONAL>
|
||||
class BayesNet : public FactorGraph<CONDITIONAL> {
|
||||
|
||||
private:
|
||||
|
||||
typedef FactorGraph<CONDITIONAL> Base;
|
||||
|
||||
public:
|
||||
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional
|
||||
typedef typename boost::shared_ptr<CONDITIONAL>
|
||||
sharedConditional; ///< A shared pointer to a conditional
|
||||
|
||||
protected:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Default constructor as an empty BayesNet */
|
||||
BayesNet() {};
|
||||
BayesNet() {}
|
||||
|
||||
/** Construct from iterator over conditionals */
|
||||
template <typename ITERATOR>
|
||||
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
||||
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||
: Base(firstConditional, lastConditional) {}
|
||||
|
||||
/// @}
|
||||
|
||||
|
@ -68,19 +67,22 @@ namespace gtsam {
|
|||
/// @{
|
||||
|
||||
/// Output to graphviz format, stream version.
|
||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
void dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DotWriter& writer = DotWriter()) const;
|
||||
|
||||
/// Output to graphviz format string.
|
||||
std::string dot(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DotWriter& writer = DotWriter()) const;
|
||||
|
||||
/// output to file with graphviz format.
|
||||
void saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DotWriter& writer = DotWriter()) const;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace gtsam
|
||||
|
||||
#include <gtsam/inference/BayesNet-inst.h>
|
||||
|
|
|
@ -25,12 +25,18 @@ using namespace std;
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
void DotWriter::writePreamble(ostream* os) const {
|
||||
void DotWriter::graphPreamble(ostream* os) const {
|
||||
*os << "graph {\n";
|
||||
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
||||
<< "\";\n\n";
|
||||
}
|
||||
|
||||
void DotWriter::digraphPreamble(ostream* os) const {
|
||||
*os << "digraph {\n";
|
||||
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
||||
<< "\";\n\n";
|
||||
}
|
||||
|
||||
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||
const boost::optional<Vector2>& position,
|
||||
ostream* os) {
|
||||
|
|
|
@ -23,10 +23,14 @@
|
|||
#include <gtsam/inference/Key.h>
|
||||
|
||||
#include <iosfwd>
|
||||
#include <map>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/// Graphviz formatter.
|
||||
/**
|
||||
* @brief DotWriter is a helper class for writing graphviz .dot files.
|
||||
* @addtogroup inference
|
||||
*/
|
||||
struct GTSAM_EXPORT DotWriter {
|
||||
double figureWidthInches; ///< The figure width on paper in inches
|
||||
double figureHeightInches; ///< The figure height on paper in inches
|
||||
|
@ -35,6 +39,9 @@ struct GTSAM_EXPORT DotWriter {
|
|||
///< the dot of the factor
|
||||
bool binaryEdges; ///< just use non-dotted edges for binary factors
|
||||
|
||||
/// (optional for each variable) Manually specify variable node positions
|
||||
std::map<gtsam::Key, Vector2> variablePositions;
|
||||
|
||||
explicit DotWriter(double figureWidthInches = 5,
|
||||
double figureHeightInches = 5,
|
||||
bool plotFactorPoints = true,
|
||||
|
@ -45,8 +52,11 @@ struct GTSAM_EXPORT DotWriter {
|
|||
connectKeysToFactor(connectKeysToFactor),
|
||||
binaryEdges(binaryEdges) {}
|
||||
|
||||
/// Write out preamble, including size.
|
||||
void writePreamble(std::ostream* os) const;
|
||||
/// Write out preamble for graph, including size.
|
||||
void graphPreamble(std::ostream* os) const;
|
||||
|
||||
/// Write out preamble for digraph, including size.
|
||||
void digraphPreamble(std::ostream* os) const;
|
||||
|
||||
/// Create a variable dot fragment.
|
||||
static void DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||
|
@ -57,6 +67,15 @@ struct GTSAM_EXPORT DotWriter {
|
|||
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||
std::ostream* os);
|
||||
|
||||
/// Return variable position or none
|
||||
boost::optional<Vector2> variablePos(Key key) const {
|
||||
auto it = variablePositions.find(key);
|
||||
if (it == variablePositions.end())
|
||||
return boost::none;
|
||||
else
|
||||
return it->second;
|
||||
}
|
||||
|
||||
/// Draw a single factor, specified by its index i and its variable keys.
|
||||
void processFactor(size_t i, const KeyVector& keys,
|
||||
const KeyFormatter& keyFormatter,
|
||||
|
|
|
@ -131,7 +131,7 @@ template <class FACTOR>
|
|||
void FactorGraph<FACTOR>::dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DotWriter& writer) const {
|
||||
writer.writePreamble(&os);
|
||||
writer.graphPreamble(&os);
|
||||
|
||||
// Create nodes for each variable in the graph
|
||||
for (Key key : keys()) {
|
||||
|
|
|
@ -301,5 +301,32 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
||||
TEST(GaussianBayesNet, Dot) {
|
||||
GaussianBayesNet fragment;
|
||||
DotWriter writer;
|
||||
writer.variablePositions.emplace(_x_, Vector2(10, 20));
|
||||
writer.variablePositions.emplace(_y_, Vector2(50, 20));
|
||||
|
||||
auto position = writer.variablePos(_x_);
|
||||
CHECK(position);
|
||||
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
|
||||
|
||||
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
|
||||
noisyBayesNet.saveGraph("noisyBayesNet.dot", DefaultKeyFormatter, writer);
|
||||
EXPECT(actual ==
|
||||
"digraph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" var11[label=\"11\", pos=\"10,20!\"];\n"
|
||||
" var22[label=\"22\", pos=\"50,20!\"];\n"
|
||||
"\n"
|
||||
" var22->var11\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
|
|||
min.y() = std::numeric_limits<double>::infinity();
|
||||
for (const Key& key : keys) {
|
||||
if (values.exists(key)) {
|
||||
boost::optional<Vector2> xy = operator()(values.at(key));
|
||||
boost::optional<Vector2> xy = extractPosition(values.at(key));
|
||||
if (xy) {
|
||||
if (xy->x() < min.x()) min.x() = xy->x();
|
||||
if (xy->y() < min.y()) min.y() = xy->y();
|
||||
|
@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
|
|||
return min;
|
||||
}
|
||||
|
||||
boost::optional<Vector2> GraphvizFormatting::operator()(
|
||||
boost::optional<Vector2> GraphvizFormatting::extractPosition(
|
||||
const Value& value) const {
|
||||
Vector3 t;
|
||||
if (const GenericValue<Pose2>* p =
|
||||
|
@ -121,12 +121,11 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
|
|||
return Vector2(x, y);
|
||||
}
|
||||
|
||||
// Return affinely transformed variable position if it exists.
|
||||
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
||||
const Vector2& min,
|
||||
Key key) const {
|
||||
if (!values.exists(key)) return boost::none;
|
||||
boost::optional<Vector2> xy = operator()(values.at(key));
|
||||
boost::optional<Vector2> xy = extractPosition(values.at(key));
|
||||
if (xy) {
|
||||
xy->x() = scale * (xy->x() - min.x());
|
||||
xy->y() = scale * (xy->y() - min.y());
|
||||
|
@ -134,7 +133,6 @@ boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
|||
return xy;
|
||||
}
|
||||
|
||||
// Return affinely transformed factor position if it exists.
|
||||
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
|
||||
size_t i) const {
|
||||
if (factorPositions.size() == 0) return boost::none;
|
||||
|
|
|
@ -56,7 +56,7 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
|
|||
Vector2 findBounds(const Values& values, const KeySet& keys) const;
|
||||
|
||||
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
|
||||
boost::optional<Vector2> operator()(const Value& value) const;
|
||||
boost::optional<Vector2> extractPosition(const Value& value) const;
|
||||
|
||||
/// Return affinely transformed variable position if it exists.
|
||||
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,
|
||||
|
|
|
@ -102,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
|
|||
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||
const KeyFormatter& keyFormatter,
|
||||
const GraphvizFormatting& writer) const {
|
||||
writer.writePreamble(&os);
|
||||
writer.graphPreamble(&os);
|
||||
|
||||
// Find bounds (imperative)
|
||||
KeySet keys = this->keys();
|
||||
|
|
|
@ -215,20 +215,19 @@ namespace gtsam {
|
|||
/// Output to graphviz format, stream version, with Values/extra options.
|
||||
void dot(std::ostream& os, const Values& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const GraphvizFormatting& graphvizFormatting =
|
||||
GraphvizFormatting()) const;
|
||||
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||
|
||||
/// Output to graphviz format string, with Values/extra options.
|
||||
std::string dot(const Values& values,
|
||||
std::string dot(
|
||||
const Values& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const GraphvizFormatting& graphvizFormatting =
|
||||
GraphvizFormatting()) const;
|
||||
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||
|
||||
/// output to file with graphviz format, with Values/extra options.
|
||||
void saveGraph(const std::string& filename, const Values& values,
|
||||
void saveGraph(
|
||||
const std::string& filename, const Values& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const GraphvizFormatting& graphvizFormatting =
|
||||
GraphvizFormatting()) const;
|
||||
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
@ -262,16 +261,16 @@ namespace gtsam {
|
|||
{return updateCholesky(values, dampen);}
|
||||
|
||||
/** @deprecated */
|
||||
void GTSAM_DEPRECATED saveGraph(
|
||||
std::ostream& os, const Values& values = Values(),
|
||||
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
|
||||
void GTSAM_DEPRECATED
|
||||
saveGraph(std::ostream& os, const Values& values = Values(),
|
||||
const GraphvizFormatting& writer = GraphvizFormatting(),
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||
dot(os, values, keyFormatter, graphvizFormatting);
|
||||
}
|
||||
/** @deprecated */
|
||||
void GTSAM_DEPRECATED
|
||||
saveGraph(const std::string& filename, const Values& values,
|
||||
const GraphvizFormatting& graphvizFormatting,
|
||||
const GraphvizFormatting& writer,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||
saveGraph(filename, values, keyFormatter, graphvizFormatting);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue