Merge pull request #1070 from borglab/feauture/dotwriter

release/4.3a0
Frank Dellaert 2022-01-27 21:44:45 -05:00 committed by GitHub
commit 84aed900c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 450 additions and 291 deletions

View File

@ -31,11 +31,12 @@
namespace gtsam {
/** A Bayes net made from discrete conditional distributions. */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{
public:
/**
* A Bayes net made from discrete conditional distributions.
* @addtogroup discrete
*/
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
public:
typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
@ -49,16 +50,20 @@ namespace gtsam {
DiscreteBayesNet() {}
/** Construct from iterator over conditionals */
template<typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
template <typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
template <class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals)
: Base(conditionals) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
/** Implicit copy/downcast constructor to override explicit template
* container constructor */
template <class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor
virtual ~DiscreteBayesNet() {}

View File

@ -102,6 +102,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const;
size_t nrParents() const;
void printSignature(
@ -156,13 +157,17 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
@ -252,14 +257,6 @@ class DiscreteFactorGraph {
void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
@ -281,6 +278,14 @@ class DiscreteFactorGraph {
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,

View File

@ -150,12 +150,21 @@ TEST(DiscreteBayesNet, Dot) {
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
string actual = fragment.dot();
cout << actual << endl;
EXPECT(actual ==
"digraph G{\n"
"0->3\n"
"4->6\n"
"3->5\n"
"6->5\n"
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"0\"];\n"
" var3[label=\"3\"];\n"
" var4[label=\"4\"];\n"
" var5[label=\"5\"];\n"
" var6[label=\"6\"];\n"
"\n"
" var3->var5\n"
" var6->var5\n"
" var4->var6\n"
" var0->var3\n"
"}");
}

View File

@ -10,41 +10,51 @@
* -------------------------------------------------------------------------- */
/**
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/
#pragma once
#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <boost/range/adaptor/reversed.hpp>
#include <fstream>
#include <string>
namespace gtsam {
/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::print(
const std::string& s, const KeyFormatter& formatter) const {
void BayesNet<CONDITIONAL>::print(const std::string& s,
const KeyFormatter& formatter) const {
Base::print(s, formatter);
}
/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const {
os << "digraph G{\n";
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
writer.digraphPreamble(&os);
for (auto conditional : *this) {
// Create nodes for each variable in the graph
for (Key key : this->keys()) {
auto position = writer.variablePos(key);
writer.drawVariable(key, keyFormatter, position, &os);
}
os << "\n";
// Reverse order as typically Bayes nets stored in reverse topological sort.
for (auto conditional : boost::adaptors::reverse(*this)) {
auto frontals = conditional->frontals();
const Key me = frontals.front();
auto parents = conditional->parents();
for (const Key& p : parents)
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
}
os << "}";
@ -53,18 +63,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,
/* ************************************************************************* */
template <class CONDITIONAL>
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::stringstream ss;
dot(ss, keyFormatter);
dot(ss, keyFormatter, writer);
return ss.str();
}
/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const {
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::ofstream of(filename.c_str());
dot(of, keyFormatter);
dot(of, keyFormatter, writer);
of.close();
}

View File

@ -10,77 +10,79 @@
* -------------------------------------------------------------------------- */
/**
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/
* @file BayesNet.h
* @brief Bayes network
* @author Frank Dellaert
* @author Richard Roberts
*/
#pragma once
#include <boost/shared_ptr.hpp>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
namespace gtsam {
/**
* A BayesNet is a tree of conditionals, stored in elimination order.
*
* todo: how to handle Bayes nets with an optimize function? Currently using global functions.
* \nosubgrouping
*/
template<class CONDITIONAL>
class BayesNet : public FactorGraph<CONDITIONAL> {
/**
* A BayesNet is a tree of conditionals, stored in elimination order.
* @addtogroup inference
*/
template <class CONDITIONAL>
class BayesNet : public FactorGraph<CONDITIONAL> {
private:
typedef FactorGraph<CONDITIONAL> Base;
private:
public:
typedef typename boost::shared_ptr<CONDITIONAL>
sharedConditional; ///< A shared pointer to a conditional
typedef FactorGraph<CONDITIONAL> Base;
protected:
/// @name Standard Constructors
/// @{
public:
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional
/** Default constructor as an empty BayesNet */
BayesNet() {}
protected:
/// @name Standard Constructors
/// @{
/** Construct from iterator over conditionals */
template <typename ITERATOR>
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Default constructor as an empty BayesNet */
BayesNet() {};
/// @}
/** Construct from iterator over conditionals */
template<typename ITERATOR>
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
public:
/// @name Testable
/// @{
/// @}
/** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
public:
/// @name Testable
/// @{
/// @}
/** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @name Graph Display
/// @{
/// @}
/// Output to graphviz format, stream version.
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// @name Graph Display
/// @{
/// Output to graphviz format string.
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
};
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
};
}
} // namespace gtsam
#include <gtsam/inference/BayesNet-inst.h>

View File

@ -16,30 +16,41 @@
* @date December, 2021
*/
#include <gtsam/base/Vector.h>
#include <gtsam/inference/DotWriter.h>
#include <gtsam/base/Vector.h>
#include <gtsam/inference/Symbol.h>
#include <ostream>
using namespace std;
namespace gtsam {
void DotWriter::writePreamble(ostream* os) const {
void DotWriter::graphPreamble(ostream* os) const {
*os << "graph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n";
}
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
void DotWriter::digraphPreamble(ostream* os) const {
*os << "digraph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n";
}
void DotWriter::drawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position,
ostream* os) {
ostream* os) const {
// Label the node with the label from the KeyFormatter
*os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key)
<< "\"";
if (position) {
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
}
if (boxes.count(key)) {
*os << ", shape=box";
}
*os << "];\n";
}
@ -53,18 +64,35 @@ void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
}
static void ConnectVariables(Key key1, Key key2,
const KeyFormatter& keyFormatter,
ostream* os) {
const KeyFormatter& keyFormatter, ostream* os) {
*os << " var" << keyFormatter(key1) << "--"
<< "var" << keyFormatter(key2) << ";\n";
}
static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter,
size_t i, ostream* os) {
size_t i, ostream* os) {
*os << " var" << keyFormatter(key) << "--"
<< "factor" << i << ";\n";
}
/// Return variable position or none
boost::optional<Vector2> DotWriter::variablePos(Key key) const {
boost::optional<Vector2> result = boost::none;
// Check position hint
Symbol symbol(key);
auto hint = positionHints.find(symbol.chr());
if (hint != positionHints.end())
result.reset(Vector2(symbol.index(), hint->second));
// Override with explicit position, if given.
auto pos = variablePositions.find(key);
if (pos != variablePositions.end())
result.reset(pos->second);
return result;
}
void DotWriter::processFactor(size_t i, const KeyVector& keys,
const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position,
@ -74,7 +102,10 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys,
ConnectVariables(keys[0], keys[1], keyFormatter, os);
} else {
// Create dot for the factor.
DrawFactor(i, position, os);
if (!position && factorPositions.count(i))
DrawFactor(i, factorPositions.at(i), os);
else
DrawFactor(i, position, os);
// Make factor-variable connections
if (connectKeysToFactor) {

View File

@ -23,10 +23,15 @@
#include <gtsam/inference/Key.h>
#include <iosfwd>
#include <map>
#include <set>
namespace gtsam {
/// Graphviz formatter.
/**
* @brief DotWriter is a helper class for writing graphviz .dot files.
* @addtogroup inference
*/
struct GTSAM_EXPORT DotWriter {
double figureWidthInches; ///< The figure width on paper in inches
double figureHeightInches; ///< The figure height on paper in inches
@ -35,6 +40,28 @@ struct GTSAM_EXPORT DotWriter {
///< the dot of the factor
bool binaryEdges; ///< just use non-dotted edges for binary factors
/**
* Variable positions can be optionally specified and will be included in the
* dot file with a "!' sign, so "neato" can use it to render them.
*/
std::map<Key, Vector2> variablePositions;
/**
* The position hints allow one to use symbol character and index to specify
* position. Unless variable positions are specified, if a hint is present for
* a given symbol, it will be used to calculate the positions as (index,hint).
*/
std::map<char, double> positionHints;
/** A set of keys that will be displayed as a box */
std::set<Key> boxes;
/**
* Factor positions can be optionally specified and will be included in the
* dot file with a "!' sign, so "neato" can use it to render them.
*/
std::map<size_t, Vector2> factorPositions;
explicit DotWriter(double figureWidthInches = 5,
double figureHeightInches = 5,
bool plotFactorPoints = true,
@ -45,18 +72,24 @@ struct GTSAM_EXPORT DotWriter {
connectKeysToFactor(connectKeysToFactor),
binaryEdges(binaryEdges) {}
/// Write out preamble, including size.
void writePreamble(std::ostream* os) const;
/// Write out preamble for graph, including size.
void graphPreamble(std::ostream* os) const;
/// Write out preamble for digraph, including size.
void digraphPreamble(std::ostream* os) const;
/// Create a variable dot fragment.
static void DrawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position,
std::ostream* os);
void drawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position,
std::ostream* os) const;
/// Create factor dot.
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
std::ostream* os);
/// Return variable position or none
boost::optional<Vector2> variablePos(Key key) const;
/// Draw a single factor, specified by its index i and its variable keys.
void processFactor(size_t i, const KeyVector& keys,
const KeyFormatter& keyFormatter,

View File

@ -131,11 +131,12 @@ template <class FACTOR>
void FactorGraph<FACTOR>::dot(std::ostream& os,
const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
writer.writePreamble(&os);
writer.graphPreamble(&os);
// Create nodes for each variable in the graph
for (Key key : keys()) {
writer.DrawVariable(key, keyFormatter, boost::none, &os);
auto position = writer.variablePos(key);
writer.drawVariable(key, keyFormatter, position, &os);
}
os << "\n";

View File

@ -127,6 +127,11 @@ class DotWriter {
bool plotFactorPoints;
bool connectKeysToFactor;
bool binaryEdges;
std::map<gtsam::Key, gtsam::Vector2> variablePositions;
std::map<char, double> positionHints;
std::set<Key> boxes;
std::map<size_t, gtsam::Vector2> factorPositions;
};
#include <gtsam/inference/VariableIndex.h>

View File

@ -205,23 +205,5 @@ namespace gtsam {
}
/* ************************************************************************* */
void GaussianBayesNet::saveGraph(const std::string& s,
const KeyFormatter& keyFormatter) const {
std::ofstream of(s.c_str());
of << "digraph G{\n";
for (auto conditional : boost::adaptors::reverse(*this)) {
typename GaussianConditional::Frontals frontals = conditional->frontals();
Key me = frontals.front();
typename GaussianConditional::Parents parents = conditional->parents();
for (Key p : parents)
of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl;
}
of << "}";
of.close();
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -21,17 +21,22 @@
#pragma once
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/global_includes.h>
#include <utility>
namespace gtsam {
/** A Bayes net made from linear-Gaussian densities */
class GTSAM_EXPORT GaussianBayesNet: public FactorGraph<GaussianConditional>
/**
* GaussianBayesNet is a Bayes net made from linear-Gaussian conditionals.
* @addtogroup linear
*/
class GTSAM_EXPORT GaussianBayesNet: public BayesNet<GaussianConditional>
{
public:
typedef FactorGraph<GaussianConditional> Base;
typedef BayesNet<GaussianConditional> Base;
typedef GaussianBayesNet This;
typedef GaussianConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr;
@ -44,16 +49,21 @@ namespace gtsam {
GaussianBayesNet() {}
/** Construct from iterator over conditionals */
template<typename ITERATOR>
GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
template <typename ITERATOR>
GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
explicit GaussianBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
template <class CONTAINER>
explicit GaussianBayesNet(const CONTAINER& conditionals) {
push_back(conditionals);
}
/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDCONDITIONAL>
GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
/** Implicit copy/downcast constructor to override explicit template
* container constructor */
template <class DERIVEDCONDITIONAL>
explicit GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor
virtual ~GaussianBayesNet() {}
@ -66,6 +76,13 @@ namespace gtsam {
/** Check equality */
bool equals(const This& bn, double tol = 1e-9) const;
/// print graph
void print(
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
/// @}
/// @name Standard Interface
@ -180,23 +197,6 @@ namespace gtsam {
*/
VectorValues backSubstituteTranspose(const VectorValues& gx) const;
/// print graph
void print(
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
/**
* @brief Save the GaussianBayesNet as an image. Requires `dot` to be
* installed.
*
* @param s The name of the figure.
* @param keyFormatter Formatter to use for styling keys in the graph.
*/
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter =
DefaultKeyFormatter) const;
/// @}
private:

View File

@ -437,42 +437,53 @@ class GaussianFactorGraph {
pair<Matrix,Vector> hessian() const;
pair<Matrix,Vector> hessian(const gtsam::Ordering& ordering) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
// enabling serialization functionality
void serialize() const;
};
#include <gtsam/linear/GaussianConditional.h>
virtual class GaussianConditional : gtsam::JacobianFactor {
//Constructors
GaussianConditional(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas);
// Constructors
GaussianConditional(size_t key, Vector d, Matrix R,
const gtsam::noiseModel::Diagonal* sigmas);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
const gtsam::noiseModel::Diagonal* sigmas);
const gtsam::noiseModel::Diagonal* sigmas);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
size_t name2, Matrix T, const gtsam::noiseModel::Diagonal* sigmas);
size_t name2, Matrix T,
const gtsam::noiseModel::Diagonal* sigmas);
//Constructors with no noise model
// Constructors with no noise model
GaussianConditional(size_t key, Vector d, Matrix R);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
size_t name2, Matrix T);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
size_t name2, Matrix T);
//Standard Interface
void print(string s = "GaussianConditional",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
// Standard Interface
void print(string s = "GaussianConditional",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
gtsam::Key firstFrontalKey() const;
// Advanced Interface
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
const gtsam::VectorValues& rhs) const;
void solveTransposeInPlace(gtsam::VectorValues& gy) const;
Matrix R() const;
Matrix S() const;
Vector d() const;
// Advanced Interface
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
const gtsam::VectorValues& rhs) const;
void solveTransposeInPlace(gtsam::VectorValues& gy) const;
Matrix R() const;
Matrix S() const;
Vector d() const;
// enabling serialization functionality
void serialize() const;
// enabling serialization functionality
void serialize() const;
};
#include <gtsam/linear/GaussianDensity.h>
@ -524,6 +535,14 @@ virtual class GaussianBayesNet {
double logDeterminant() const;
gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const;
gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
};
#include <gtsam/linear/GaussianBayesTree.h>

View File

@ -301,5 +301,31 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) {
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
TEST(GaussianBayesNet, Dot) {
GaussianBayesNet fragment;
DotWriter writer;
writer.variablePositions.emplace(_x_, Vector2(10, 20));
writer.variablePositions.emplace(_y_, Vector2(50, 20));
auto position = writer.variablePos(_x_);
CHECK(position);
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var11[label=\"11\", pos=\"10,20!\"];\n"
" var22[label=\"22\", pos=\"50,20!\"];\n"
"\n"
" var22->var11\n"
"}");
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
min.y() = std::numeric_limits<double>::infinity();
for (const Key& key : keys) {
if (values.exists(key)) {
boost::optional<Vector2> xy = operator()(values.at(key));
boost::optional<Vector2> xy = extractPosition(values.at(key));
if (xy) {
if (xy->x() < min.x()) min.x() = xy->x();
if (xy->y() < min.y()) min.y() = xy->y();
@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
return min;
}
boost::optional<Vector2> GraphvizFormatting::operator()(
boost::optional<Vector2> GraphvizFormatting::extractPosition(
const Value& value) const {
Vector3 t;
if (const GenericValue<Pose2>* p =
@ -121,12 +121,11 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
return Vector2(x, y);
}
// Return affinely transformed variable position if it exists.
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
const Vector2& min,
Key key) const {
if (!values.exists(key)) return boost::none;
boost::optional<Vector2> xy = operator()(values.at(key));
if (!values.exists(key)) return DotWriter::variablePos(key);
boost::optional<Vector2> xy = extractPosition(values.at(key));
if (xy) {
xy->x() = scale * (xy->x() - min.x());
xy->y() = scale * (xy->y() - min.y());
@ -134,7 +133,6 @@ boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
return xy;
}
// Return affinely transformed factor position if it exists.
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
size_t i) const {
if (factorPositions.size() == 0) return boost::none;

View File

@ -33,17 +33,14 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
/// World axes to be assigned to paper axes
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
///< paper axis
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
///< axis
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
///< paper axis
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
///< axis
double scale; ///< Scale all positions to reduce / increase density
bool mergeSimilarFactors; ///< Merge multiple factors that have the same
///< connectivity
/// (optional for each factor) Manually specify factor "dot" positions:
std::map<size_t, Vector2> factorPositions;
/// Default constructor sets up robot coordinates. Paper horizontal is robot
/// Y, paper vertical is robot X. Default figure size of 5x5 in.
GraphvizFormatting()
@ -55,8 +52,8 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
// Find bounds
Vector2 findBounds(const Values& values, const KeySet& keys) const;
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
boost::optional<Vector2> operator()(const Value& value) const;
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
boost::optional<Vector2> extractPosition(const Value& value) const;
/// Return affinely transformed variable position if it exists.
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,

View File

@ -102,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
const KeyFormatter& keyFormatter,
const GraphvizFormatting& writer) const {
writer.writePreamble(&os);
writer.graphPreamble(&os);
// Find bounds (imperative)
KeySet keys = this->keys();
@ -111,7 +111,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
// Create nodes for each variable in the graph
for (Key key : keys) {
auto position = writer.variablePos(values, min, key);
writer.DrawVariable(key, keyFormatter, position, &os);
writer.drawVariable(key, keyFormatter, position, &os);
}
os << "\n";

View File

@ -43,12 +43,14 @@ namespace gtsam {
class ExpressionFactor;
/**
* A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors,
* which derive from NonlinearFactor. The values structures are typically (in SAM) more general
* than just vectors, e.g., Rot3 or Pose3, which are objects in non-linear manifolds.
* Linearizing the non-linear factor graph creates a linear factor graph on the
* tangent vector space at the linearization point. Because the tangent space is a true
* vector space, the config type will be an VectorValues in that linearized factor graph.
* A NonlinearFactorGraph is a graph of non-Gaussian, i.e. non-linear factors,
* which derive from NonlinearFactor. The values structures are typically (in
* SAM) more general than just vectors, e.g., Rot3 or Pose3, which are objects
* in non-linear manifolds. Linearizing the non-linear factor graph creates a
* linear factor graph on the tangent vector space at the linearization point.
* Because the tangent space is a true vector space, the config type will be
* an VectorValues in that linearized factor graph.
* @addtogroup nonlinear
*/
class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> {
@ -58,6 +60,9 @@ namespace gtsam {
typedef NonlinearFactorGraph This;
typedef boost::shared_ptr<This> shared_ptr;
/// @name Standard Constructors
/// @{
/** Default constructor */
NonlinearFactorGraph() {}
@ -76,6 +81,10 @@ namespace gtsam {
/// Destructor
virtual ~NonlinearFactorGraph() {}
/// @}
/// @name Testable
/// @{
/** print */
void print(
const std::string& str = "NonlinearFactorGraph: ",
@ -90,6 +99,10 @@ namespace gtsam {
/** Test equality */
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
double error(const Values& values) const;
@ -206,6 +219,7 @@ namespace gtsam {
emplace_shared<PriorFactor<T>>(key, prior, covariance);
}
/// @}
/// @name Graph Display
/// @{
@ -215,20 +229,19 @@ namespace gtsam {
/// Output to graphviz format, stream version, with Values/extra options.
void dot(std::ostream& os, const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& graphvizFormatting =
GraphvizFormatting()) const;
const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// Output to graphviz format string, with Values/extra options.
std::string dot(const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& graphvizFormatting =
GraphvizFormatting()) const;
std::string dot(
const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// output to file with graphviz format, with Values/extra options.
void saveGraph(const std::string& filename, const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& graphvizFormatting =
GraphvizFormatting()) const;
void saveGraph(
const std::string& filename, const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// @}
private:
@ -251,6 +264,8 @@ namespace gtsam {
public:
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated
/// @{
/** @deprecated */
boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor(
const Values& values, boost::none_t, const Dampen& dampen = nullptr) const
@ -275,6 +290,7 @@ namespace gtsam {
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
saveGraph(filename, values, keyFormatter, graphvizFormatting);
}
/// @}
#endif
};

View File

@ -95,18 +95,17 @@ class NonlinearFactorGraph {
gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const;
gtsam::NonlinearFactorGraph clone() const;
// enabling serialization functionality
void serialize() const;
string dot(
const gtsam::Values& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const GraphvizFormatting& writer = GraphvizFormatting());
void saveGraph(const string& s, const gtsam::Values& values,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter,
const GraphvizFormatting& writer =
GraphvizFormatting()) const;
const GraphvizFormatting& formatting = GraphvizFormatting());
void saveGraph(
const string& s, const gtsam::Values& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const GraphvizFormatting& formatting = GraphvizFormatting()) const;
// enabling serialization functionality
void serialize() const;
};
#include <gtsam/nonlinear/NonlinearFactor.h>

View File

@ -16,41 +16,16 @@
* @author Richard Roberts
*/
#include <gtsam/symbolic/SymbolicBayesNet.h>
#include <gtsam/symbolic/SymbolicConditional.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <boost/range/adaptor/reversed.hpp>
#include <fstream>
#include <gtsam/symbolic/SymbolicBayesNet.h>
namespace gtsam {
// Instantiate base class
template class FactorGraph<SymbolicConditional>;
/* ************************************************************************* */
bool SymbolicBayesNet::equals(const This& bn, double tol) const
{
return Base::equals(bn, tol);
}
/* ************************************************************************* */
void SymbolicBayesNet::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const
{
std::ofstream of(s.c_str());
of << "digraph G{\n";
for (auto conditional: boost::adaptors::reverse(*this)) {
SymbolicConditional::Frontals frontals = conditional->frontals();
Key me = frontals.front();
SymbolicConditional::Parents parents = conditional->parents();
for(Key p: parents)
of << p << "->" << me << std::endl;
}
of << "}";
of.close();
}
// Instantiate base class
template class FactorGraph<SymbolicConditional>;
/* ************************************************************************* */
bool SymbolicBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
}
} // namespace gtsam

View File

@ -19,19 +19,19 @@
#pragma once
#include <gtsam/symbolic/SymbolicConditional.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/base/types.h>
namespace gtsam {
/** Symbolic Bayes Net
* \nosubgrouping
/**
* A SymbolicBayesNet is a Bayes Net of purely symbolic conditionals.
* @addtogroup symbolic
*/
class SymbolicBayesNet : public FactorGraph<SymbolicConditional> {
public:
typedef FactorGraph<SymbolicConditional> Base;
class SymbolicBayesNet : public BayesNet<SymbolicConditional> {
public:
typedef BayesNet<SymbolicConditional> Base;
typedef SymbolicBayesNet This;
typedef SymbolicConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr;
@ -44,16 +44,21 @@ namespace gtsam {
SymbolicBayesNet() {}
/** Construct from iterator over conditionals */
template<typename ITERATOR>
SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
template <typename ITERATOR>
SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
explicit SymbolicBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
template <class CONTAINER>
explicit SymbolicBayesNet(const CONTAINER& conditionals) {
push_back(conditionals);
}
/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDCONDITIONAL>
SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
/** Implicit copy/downcast constructor to override explicit template
* container constructor */
template <class DERIVEDCONDITIONAL>
explicit SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor
virtual ~SymbolicBayesNet() {}
@ -75,13 +80,6 @@ namespace gtsam {
/// @}
/// @name Standard Interface
/// @{
GTSAM_EXPORT void saveGraph(const std::string &s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;

View File

@ -77,6 +77,14 @@ virtual class SymbolicFactorGraph {
const gtsam::KeyVector& key_vector,
const gtsam::Ordering& marginalizedVariableOrdering);
gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
};
#include <gtsam/symbolic/SymbolicConditional.h>
@ -98,6 +106,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor {
bool equals(const gtsam::SymbolicConditional& other, double tol) const;
// Standard interface
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const;
size_t nrParents() const;
};
@ -120,6 +129,14 @@ class SymbolicBayesNet {
gtsam::SymbolicConditional* back() const;
void push_back(gtsam::SymbolicConditional* conditional);
void push_back(const gtsam::SymbolicBayesNet& bayesNet);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
};
#include <gtsam/symbolic/SymbolicBayesTree.h>

View File

@ -15,13 +15,16 @@
* @author Frank Dellaert
*/
#include <boost/make_shared.hpp>
#include <gtsam/symbolic/SymbolicBayesNet.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/Vector.h>
#include <gtsam/base/VectorSpace.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/symbolic/SymbolicConditional.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/symbolic/SymbolicBayesNet.h>
#include <gtsam/symbolic/SymbolicConditional.h>
#include <boost/make_shared.hpp>
using namespace std;
using namespace gtsam;
@ -30,7 +33,6 @@ static const Key _L_ = 0;
static const Key _A_ = 1;
static const Key _B_ = 2;
static const Key _C_ = 3;
static const Key _D_ = 4;
static SymbolicConditional::shared_ptr
B(new SymbolicConditional(_B_)),
@ -78,14 +80,41 @@ TEST( SymbolicBayesNet, combine )
}
/* ************************************************************************* */
TEST(SymbolicBayesNet, saveGraph) {
TEST(SymbolicBayesNet, Dot) {
using symbol_shorthand::A;
using symbol_shorthand::X;
SymbolicBayesNet bn;
bn += SymbolicConditional(_A_, _B_);
KeyVector keys {_B_, _C_, _D_};
bn += SymbolicConditional::FromKeys(keys,2);
bn += SymbolicConditional(_D_);
bn += SymbolicConditional(X(3), X(2), A(2));
bn += SymbolicConditional(X(2), X(1), A(1));
bn += SymbolicConditional(X(1));
bn.saveGraph("SymbolicBayesNet.dot");
DotWriter writer;
writer.positionHints.emplace('a', 2);
writer.positionHints.emplace('x', 1);
writer.boxes.emplace(A(1));
writer.boxes.emplace(A(2));
auto position = writer.variablePos(A(1));
CHECK(position);
EXPECT(assert_equal(Vector2(1, 2), *position, 1e-5));
string actual = bn.dot(DefaultKeyFormatter, writer);
bn.saveGraph("bn.dot", DefaultKeyFormatter, writer);
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" vara1[label=\"a1\", pos=\"1,2!\", shape=box];\n"
" vara2[label=\"a2\", pos=\"2,2!\", shape=box];\n"
" varx1[label=\"x1\", pos=\"1,1!\"];\n"
" varx2[label=\"x2\", pos=\"2,1!\"];\n"
" varx3[label=\"x3\", pos=\"3,1!\"];\n"
"\n"
" varx1->varx2\n"
" vara1->varx2\n"
" varx2->varx3\n"
" vara2->varx3\n"
"}");
}
/* ************************************************************************* */

View File

@ -78,7 +78,7 @@ class TestGraphvizFormatting(GtsamTestCase):
graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X
graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y
self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting),
formatting=graphviz_formatting),
textwrap.dedent(expected_result))
def test_factor_points(self):
@ -100,7 +100,7 @@ class TestGraphvizFormatting(GtsamTestCase):
graphviz_formatting.plotFactorPoints = False
self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting),
formatting=graphviz_formatting),
textwrap.dedent(expected_result))
def test_width_height(self):
@ -127,7 +127,7 @@ class TestGraphvizFormatting(GtsamTestCase):
graphviz_formatting.figureHeightInches = 10
self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting),
formatting=graphviz_formatting),
textwrap.dedent(expected_result))