Merge pull request #1070 from borglab/feauture/dotwriter
commit
84aed900c9
|
@ -31,11 +31,12 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/** A Bayes net made from discrete conditional distributions. */
|
||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
|
||||
{
|
||||
public:
|
||||
|
||||
/**
|
||||
* A Bayes net made from discrete conditional distributions.
|
||||
* @addtogroup discrete
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||
public:
|
||||
typedef BayesNet<DiscreteConditional> Base;
|
||||
typedef DiscreteBayesNet This;
|
||||
typedef DiscreteConditional ConditionalType;
|
||||
|
@ -49,16 +50,20 @@ namespace gtsam {
|
|||
DiscreteBayesNet() {}
|
||||
|
||||
/** Construct from iterator over conditionals */
|
||||
template<typename ITERATOR>
|
||||
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
||||
template <typename ITERATOR>
|
||||
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||
: Base(firstConditional, lastConditional) {}
|
||||
|
||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||
template<class CONTAINER>
|
||||
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
||||
template <class CONTAINER>
|
||||
explicit DiscreteBayesNet(const CONTAINER& conditionals)
|
||||
: Base(conditionals) {}
|
||||
|
||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
||||
template<class DERIVEDCONDITIONAL>
|
||||
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
||||
/** Implicit copy/downcast constructor to override explicit template
|
||||
* container constructor */
|
||||
template <class DERIVEDCONDITIONAL>
|
||||
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||
: Base(graph) {}
|
||||
|
||||
/// Destructor
|
||||
virtual ~DiscreteBayesNet() {}
|
||||
|
|
|
@ -102,6 +102,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||
gtsam::Key firstFrontalKey() const;
|
||||
size_t nrFrontals() const;
|
||||
size_t nrParents() const;
|
||||
void printSignature(
|
||||
|
@ -156,13 +157,17 @@ class DiscreteBayesNet {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues sample() const;
|
||||
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
|
@ -252,14 +257,6 @@ class DiscreteFactorGraph {
|
|||
void print(string s = "") const;
|
||||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
|
||||
|
||||
gtsam::DecisionTreeFactor product() const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
|
@ -281,6 +278,14 @@ class DiscreteFactorGraph {
|
|||
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
|
||||
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
|
|
|
@ -150,12 +150,21 @@ TEST(DiscreteBayesNet, Dot) {
|
|||
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||
|
||||
string actual = fragment.dot();
|
||||
cout << actual << endl;
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
"0->3\n"
|
||||
"4->6\n"
|
||||
"3->5\n"
|
||||
"6->5\n"
|
||||
"digraph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" var0[label=\"0\"];\n"
|
||||
" var3[label=\"3\"];\n"
|
||||
" var4[label=\"4\"];\n"
|
||||
" var5[label=\"5\"];\n"
|
||||
" var6[label=\"6\"];\n"
|
||||
"\n"
|
||||
" var3->var5\n"
|
||||
" var6->var5\n"
|
||||
" var4->var6\n"
|
||||
" var0->var3\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
|
|
|
@ -10,41 +10,51 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file BayesNet.h
|
||||
* @brief Bayes network
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
* @file BayesNet.h
|
||||
* @brief Bayes network
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
|
||||
#include <boost/range/adaptor/reversed.hpp>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
void BayesNet<CONDITIONAL>::print(
|
||||
const std::string& s, const KeyFormatter& formatter) const {
|
||||
void BayesNet<CONDITIONAL>::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
os << "digraph G{\n";
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DotWriter& writer) const {
|
||||
writer.digraphPreamble(&os);
|
||||
|
||||
for (auto conditional : *this) {
|
||||
// Create nodes for each variable in the graph
|
||||
for (Key key : this->keys()) {
|
||||
auto position = writer.variablePos(key);
|
||||
writer.drawVariable(key, keyFormatter, position, &os);
|
||||
}
|
||||
os << "\n";
|
||||
|
||||
// Reverse order as typically Bayes nets stored in reverse topological sort.
|
||||
for (auto conditional : boost::adaptors::reverse(*this)) {
|
||||
auto frontals = conditional->frontals();
|
||||
const Key me = frontals.front();
|
||||
auto parents = conditional->parents();
|
||||
for (const Key& p : parents)
|
||||
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
|
||||
os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
|
||||
}
|
||||
|
||||
os << "}";
|
||||
|
@ -53,18 +63,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
|||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
|
||||
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
|
||||
const DotWriter& writer) const {
|
||||
std::stringstream ss;
|
||||
dot(ss, keyFormatter);
|
||||
dot(ss, keyFormatter, writer);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DotWriter& writer) const {
|
||||
std::ofstream of(filename.c_str());
|
||||
dot(of, keyFormatter);
|
||||
dot(of, keyFormatter, writer);
|
||||
of.close();
|
||||
}
|
||||
|
||||
|
|
|
@ -10,77 +10,79 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file BayesNet.h
|
||||
* @brief Bayes network
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
* @file BayesNet.h
|
||||
* @brief Bayes network
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <string>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* A BayesNet is a tree of conditionals, stored in elimination order.
|
||||
*
|
||||
* todo: how to handle Bayes nets with an optimize function? Currently using global functions.
|
||||
* \nosubgrouping
|
||||
*/
|
||||
template<class CONDITIONAL>
|
||||
class BayesNet : public FactorGraph<CONDITIONAL> {
|
||||
/**
|
||||
* A BayesNet is a tree of conditionals, stored in elimination order.
|
||||
* @addtogroup inference
|
||||
*/
|
||||
template <class CONDITIONAL>
|
||||
class BayesNet : public FactorGraph<CONDITIONAL> {
|
||||
private:
|
||||
typedef FactorGraph<CONDITIONAL> Base;
|
||||
|
||||
private:
|
||||
public:
|
||||
typedef typename boost::shared_ptr<CONDITIONAL>
|
||||
sharedConditional; ///< A shared pointer to a conditional
|
||||
|
||||
typedef FactorGraph<CONDITIONAL> Base;
|
||||
protected:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
public:
|
||||
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional
|
||||
/** Default constructor as an empty BayesNet */
|
||||
BayesNet() {}
|
||||
|
||||
protected:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
/** Construct from iterator over conditionals */
|
||||
template <typename ITERATOR>
|
||||
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||
: Base(firstConditional, lastConditional) {}
|
||||
|
||||
/** Default constructor as an empty BayesNet */
|
||||
BayesNet() {};
|
||||
/// @}
|
||||
|
||||
/** Construct from iterator over conditionals */
|
||||
template<typename ITERATOR>
|
||||
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
||||
public:
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/// @}
|
||||
/** print out graph */
|
||||
void print(
|
||||
const std::string& s = "BayesNet",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
public:
|
||||
/// @name Testable
|
||||
/// @{
|
||||
/// @}
|
||||
|
||||
/** print out graph */
|
||||
void print(
|
||||
const std::string& s = "BayesNet",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
/// @name Graph Display
|
||||
/// @{
|
||||
|
||||
/// @}
|
||||
/// Output to graphviz format, stream version.
|
||||
void dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DotWriter& writer = DotWriter()) const;
|
||||
|
||||
/// @name Graph Display
|
||||
/// @{
|
||||
/// Output to graphviz format string.
|
||||
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DotWriter& writer = DotWriter()) const;
|
||||
|
||||
/// Output to graphviz format, stream version.
|
||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
/// output to file with graphviz format.
|
||||
void saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const DotWriter& writer = DotWriter()) const;
|
||||
|
||||
/// Output to graphviz format string.
|
||||
std::string dot(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
/// @}
|
||||
};
|
||||
|
||||
/// output to file with graphviz format.
|
||||
void saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace gtsam
|
||||
|
||||
#include <gtsam/inference/BayesNet-inst.h>
|
||||
|
|
|
@ -16,30 +16,41 @@
|
|||
* @date December, 2021
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/inference/DotWriter.h>
|
||||
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
#include <ostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
void DotWriter::writePreamble(ostream* os) const {
|
||||
void DotWriter::graphPreamble(ostream* os) const {
|
||||
*os << "graph {\n";
|
||||
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
||||
<< "\";\n\n";
|
||||
}
|
||||
|
||||
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||
void DotWriter::digraphPreamble(ostream* os) const {
|
||||
*os << "digraph {\n";
|
||||
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
||||
<< "\";\n\n";
|
||||
}
|
||||
|
||||
void DotWriter::drawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||
const boost::optional<Vector2>& position,
|
||||
ostream* os) {
|
||||
ostream* os) const {
|
||||
// Label the node with the label from the KeyFormatter
|
||||
*os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key)
|
||||
<< "\"";
|
||||
if (position) {
|
||||
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
|
||||
}
|
||||
if (boxes.count(key)) {
|
||||
*os << ", shape=box";
|
||||
}
|
||||
*os << "];\n";
|
||||
}
|
||||
|
||||
|
@ -53,18 +64,35 @@ void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
|||
}
|
||||
|
||||
static void ConnectVariables(Key key1, Key key2,
|
||||
const KeyFormatter& keyFormatter,
|
||||
ostream* os) {
|
||||
const KeyFormatter& keyFormatter, ostream* os) {
|
||||
*os << " var" << keyFormatter(key1) << "--"
|
||||
<< "var" << keyFormatter(key2) << ";\n";
|
||||
}
|
||||
|
||||
static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter,
|
||||
size_t i, ostream* os) {
|
||||
size_t i, ostream* os) {
|
||||
*os << " var" << keyFormatter(key) << "--"
|
||||
<< "factor" << i << ";\n";
|
||||
}
|
||||
|
||||
/// Return variable position or none
|
||||
boost::optional<Vector2> DotWriter::variablePos(Key key) const {
|
||||
boost::optional<Vector2> result = boost::none;
|
||||
|
||||
// Check position hint
|
||||
Symbol symbol(key);
|
||||
auto hint = positionHints.find(symbol.chr());
|
||||
if (hint != positionHints.end())
|
||||
result.reset(Vector2(symbol.index(), hint->second));
|
||||
|
||||
// Override with explicit position, if given.
|
||||
auto pos = variablePositions.find(key);
|
||||
if (pos != variablePositions.end())
|
||||
result.reset(pos->second);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
||||
const KeyFormatter& keyFormatter,
|
||||
const boost::optional<Vector2>& position,
|
||||
|
@ -74,7 +102,10 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
|||
ConnectVariables(keys[0], keys[1], keyFormatter, os);
|
||||
} else {
|
||||
// Create dot for the factor.
|
||||
DrawFactor(i, position, os);
|
||||
if (!position && factorPositions.count(i))
|
||||
DrawFactor(i, factorPositions.at(i), os);
|
||||
else
|
||||
DrawFactor(i, position, os);
|
||||
|
||||
// Make factor-variable connections
|
||||
if (connectKeysToFactor) {
|
||||
|
|
|
@ -23,10 +23,15 @@
|
|||
#include <gtsam/inference/Key.h>
|
||||
|
||||
#include <iosfwd>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/// Graphviz formatter.
|
||||
/**
|
||||
* @brief DotWriter is a helper class for writing graphviz .dot files.
|
||||
* @addtogroup inference
|
||||
*/
|
||||
struct GTSAM_EXPORT DotWriter {
|
||||
double figureWidthInches; ///< The figure width on paper in inches
|
||||
double figureHeightInches; ///< The figure height on paper in inches
|
||||
|
@ -35,6 +40,28 @@ struct GTSAM_EXPORT DotWriter {
|
|||
///< the dot of the factor
|
||||
bool binaryEdges; ///< just use non-dotted edges for binary factors
|
||||
|
||||
/**
|
||||
* Variable positions can be optionally specified and will be included in the
|
||||
* dot file with a "!' sign, so "neato" can use it to render them.
|
||||
*/
|
||||
std::map<Key, Vector2> variablePositions;
|
||||
|
||||
/**
|
||||
* The position hints allow one to use symbol character and index to specify
|
||||
* position. Unless variable positions are specified, if a hint is present for
|
||||
* a given symbol, it will be used to calculate the positions as (index,hint).
|
||||
*/
|
||||
std::map<char, double> positionHints;
|
||||
|
||||
/** A set of keys that will be displayed as a box */
|
||||
std::set<Key> boxes;
|
||||
|
||||
/**
|
||||
* Factor positions can be optionally specified and will be included in the
|
||||
* dot file with a "!' sign, so "neato" can use it to render them.
|
||||
*/
|
||||
std::map<size_t, Vector2> factorPositions;
|
||||
|
||||
explicit DotWriter(double figureWidthInches = 5,
|
||||
double figureHeightInches = 5,
|
||||
bool plotFactorPoints = true,
|
||||
|
@ -45,18 +72,24 @@ struct GTSAM_EXPORT DotWriter {
|
|||
connectKeysToFactor(connectKeysToFactor),
|
||||
binaryEdges(binaryEdges) {}
|
||||
|
||||
/// Write out preamble, including size.
|
||||
void writePreamble(std::ostream* os) const;
|
||||
/// Write out preamble for graph, including size.
|
||||
void graphPreamble(std::ostream* os) const;
|
||||
|
||||
/// Write out preamble for digraph, including size.
|
||||
void digraphPreamble(std::ostream* os) const;
|
||||
|
||||
/// Create a variable dot fragment.
|
||||
static void DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||
const boost::optional<Vector2>& position,
|
||||
std::ostream* os);
|
||||
void drawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||
const boost::optional<Vector2>& position,
|
||||
std::ostream* os) const;
|
||||
|
||||
/// Create factor dot.
|
||||
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||
std::ostream* os);
|
||||
|
||||
/// Return variable position or none
|
||||
boost::optional<Vector2> variablePos(Key key) const;
|
||||
|
||||
/// Draw a single factor, specified by its index i and its variable keys.
|
||||
void processFactor(size_t i, const KeyVector& keys,
|
||||
const KeyFormatter& keyFormatter,
|
||||
|
|
|
@ -131,11 +131,12 @@ template <class FACTOR>
|
|||
void FactorGraph<FACTOR>::dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DotWriter& writer) const {
|
||||
writer.writePreamble(&os);
|
||||
writer.graphPreamble(&os);
|
||||
|
||||
// Create nodes for each variable in the graph
|
||||
for (Key key : keys()) {
|
||||
writer.DrawVariable(key, keyFormatter, boost::none, &os);
|
||||
auto position = writer.variablePos(key);
|
||||
writer.drawVariable(key, keyFormatter, position, &os);
|
||||
}
|
||||
os << "\n";
|
||||
|
||||
|
|
|
@ -127,6 +127,11 @@ class DotWriter {
|
|||
bool plotFactorPoints;
|
||||
bool connectKeysToFactor;
|
||||
bool binaryEdges;
|
||||
|
||||
std::map<gtsam::Key, gtsam::Vector2> variablePositions;
|
||||
std::map<char, double> positionHints;
|
||||
std::set<Key> boxes;
|
||||
std::map<size_t, gtsam::Vector2> factorPositions;
|
||||
};
|
||||
|
||||
#include <gtsam/inference/VariableIndex.h>
|
||||
|
|
|
@ -205,23 +205,5 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void GaussianBayesNet::saveGraph(const std::string& s,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::ofstream of(s.c_str());
|
||||
of << "digraph G{\n";
|
||||
|
||||
for (auto conditional : boost::adaptors::reverse(*this)) {
|
||||
typename GaussianConditional::Frontals frontals = conditional->frontals();
|
||||
Key me = frontals.front();
|
||||
typename GaussianConditional::Parents parents = conditional->parents();
|
||||
for (Key p : parents)
|
||||
of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl;
|
||||
}
|
||||
|
||||
of << "}";
|
||||
of.close();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -21,17 +21,22 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/global_includes.h>
|
||||
|
||||
#include <utility>
|
||||
namespace gtsam {
|
||||
|
||||
/** A Bayes net made from linear-Gaussian densities */
|
||||
class GTSAM_EXPORT GaussianBayesNet: public FactorGraph<GaussianConditional>
|
||||
/**
|
||||
* GaussianBayesNet is a Bayes net made from linear-Gaussian conditionals.
|
||||
* @addtogroup linear
|
||||
*/
|
||||
class GTSAM_EXPORT GaussianBayesNet: public BayesNet<GaussianConditional>
|
||||
{
|
||||
public:
|
||||
|
||||
typedef FactorGraph<GaussianConditional> Base;
|
||||
typedef BayesNet<GaussianConditional> Base;
|
||||
typedef GaussianBayesNet This;
|
||||
typedef GaussianConditional ConditionalType;
|
||||
typedef boost::shared_ptr<This> shared_ptr;
|
||||
|
@ -44,16 +49,21 @@ namespace gtsam {
|
|||
GaussianBayesNet() {}
|
||||
|
||||
/** Construct from iterator over conditionals */
|
||||
template<typename ITERATOR>
|
||||
GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
||||
template <typename ITERATOR>
|
||||
GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||
: Base(firstConditional, lastConditional) {}
|
||||
|
||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||
template<class CONTAINER>
|
||||
explicit GaussianBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
||||
template <class CONTAINER>
|
||||
explicit GaussianBayesNet(const CONTAINER& conditionals) {
|
||||
push_back(conditionals);
|
||||
}
|
||||
|
||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
||||
template<class DERIVEDCONDITIONAL>
|
||||
GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
||||
/** Implicit copy/downcast constructor to override explicit template
|
||||
* container constructor */
|
||||
template <class DERIVEDCONDITIONAL>
|
||||
explicit GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||
: Base(graph) {}
|
||||
|
||||
/// Destructor
|
||||
virtual ~GaussianBayesNet() {}
|
||||
|
@ -66,6 +76,13 @@ namespace gtsam {
|
|||
/** Check equality */
|
||||
bool equals(const This& bn, double tol = 1e-9) const;
|
||||
|
||||
/// print graph
|
||||
void print(
|
||||
const std::string& s = "",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Standard Interface
|
||||
|
@ -180,23 +197,6 @@ namespace gtsam {
|
|||
*/
|
||||
VectorValues backSubstituteTranspose(const VectorValues& gx) const;
|
||||
|
||||
/// print graph
|
||||
void print(
|
||||
const std::string& s = "",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Save the GaussianBayesNet as an image. Requires `dot` to be
|
||||
* installed.
|
||||
*
|
||||
* @param s The name of the figure.
|
||||
* @param keyFormatter Formatter to use for styling keys in the graph.
|
||||
*/
|
||||
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter =
|
||||
DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
@ -437,42 +437,53 @@ class GaussianFactorGraph {
|
|||
pair<Matrix,Vector> hessian() const;
|
||||
pair<Matrix,Vector> hessian(const gtsam::Ordering& ordering) const;
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
|
||||
// enabling serialization functionality
|
||||
void serialize() const;
|
||||
};
|
||||
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||
//Constructors
|
||||
GaussianConditional(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas);
|
||||
// Constructors
|
||||
GaussianConditional(size_t key, Vector d, Matrix R,
|
||||
const gtsam::noiseModel::Diagonal* sigmas);
|
||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
||||
const gtsam::noiseModel::Diagonal* sigmas);
|
||||
const gtsam::noiseModel::Diagonal* sigmas);
|
||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
||||
size_t name2, Matrix T, const gtsam::noiseModel::Diagonal* sigmas);
|
||||
size_t name2, Matrix T,
|
||||
const gtsam::noiseModel::Diagonal* sigmas);
|
||||
|
||||
//Constructors with no noise model
|
||||
// Constructors with no noise model
|
||||
GaussianConditional(size_t key, Vector d, Matrix R);
|
||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S);
|
||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
||||
size_t name2, Matrix T);
|
||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S);
|
||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
||||
size_t name2, Matrix T);
|
||||
|
||||
//Standard Interface
|
||||
void print(string s = "GaussianConditional",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
||||
// Standard Interface
|
||||
void print(string s = "GaussianConditional",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
||||
gtsam::Key firstFrontalKey() const;
|
||||
|
||||
// Advanced Interface
|
||||
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
|
||||
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
|
||||
const gtsam::VectorValues& rhs) const;
|
||||
void solveTransposeInPlace(gtsam::VectorValues& gy) const;
|
||||
Matrix R() const;
|
||||
Matrix S() const;
|
||||
Vector d() const;
|
||||
|
||||
// Advanced Interface
|
||||
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
|
||||
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
|
||||
const gtsam::VectorValues& rhs) const;
|
||||
void solveTransposeInPlace(gtsam::VectorValues& gy) const;
|
||||
Matrix R() const;
|
||||
Matrix S() const;
|
||||
Vector d() const;
|
||||
|
||||
// enabling serialization functionality
|
||||
void serialize() const;
|
||||
// enabling serialization functionality
|
||||
void serialize() const;
|
||||
};
|
||||
|
||||
#include <gtsam/linear/GaussianDensity.h>
|
||||
|
@ -524,6 +535,14 @@ virtual class GaussianBayesNet {
|
|||
double logDeterminant() const;
|
||||
gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const;
|
||||
gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const;
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
};
|
||||
|
||||
#include <gtsam/linear/GaussianBayesTree.h>
|
||||
|
|
|
@ -301,5 +301,31 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
||||
TEST(GaussianBayesNet, Dot) {
|
||||
GaussianBayesNet fragment;
|
||||
DotWriter writer;
|
||||
writer.variablePositions.emplace(_x_, Vector2(10, 20));
|
||||
writer.variablePositions.emplace(_y_, Vector2(50, 20));
|
||||
|
||||
auto position = writer.variablePos(_x_);
|
||||
CHECK(position);
|
||||
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
|
||||
|
||||
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
|
||||
EXPECT(actual ==
|
||||
"digraph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" var11[label=\"11\", pos=\"10,20!\"];\n"
|
||||
" var22[label=\"22\", pos=\"50,20!\"];\n"
|
||||
"\n"
|
||||
" var22->var11\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
|
|||
min.y() = std::numeric_limits<double>::infinity();
|
||||
for (const Key& key : keys) {
|
||||
if (values.exists(key)) {
|
||||
boost::optional<Vector2> xy = operator()(values.at(key));
|
||||
boost::optional<Vector2> xy = extractPosition(values.at(key));
|
||||
if (xy) {
|
||||
if (xy->x() < min.x()) min.x() = xy->x();
|
||||
if (xy->y() < min.y()) min.y() = xy->y();
|
||||
|
@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
|
|||
return min;
|
||||
}
|
||||
|
||||
boost::optional<Vector2> GraphvizFormatting::operator()(
|
||||
boost::optional<Vector2> GraphvizFormatting::extractPosition(
|
||||
const Value& value) const {
|
||||
Vector3 t;
|
||||
if (const GenericValue<Pose2>* p =
|
||||
|
@ -121,12 +121,11 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
|
|||
return Vector2(x, y);
|
||||
}
|
||||
|
||||
// Return affinely transformed variable position if it exists.
|
||||
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
||||
const Vector2& min,
|
||||
Key key) const {
|
||||
if (!values.exists(key)) return boost::none;
|
||||
boost::optional<Vector2> xy = operator()(values.at(key));
|
||||
if (!values.exists(key)) return DotWriter::variablePos(key);
|
||||
boost::optional<Vector2> xy = extractPosition(values.at(key));
|
||||
if (xy) {
|
||||
xy->x() = scale * (xy->x() - min.x());
|
||||
xy->y() = scale * (xy->y() - min.y());
|
||||
|
@ -134,7 +133,6 @@ boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
|||
return xy;
|
||||
}
|
||||
|
||||
// Return affinely transformed factor position if it exists.
|
||||
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
|
||||
size_t i) const {
|
||||
if (factorPositions.size() == 0) return boost::none;
|
||||
|
|
|
@ -33,17 +33,14 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
|
|||
/// World axes to be assigned to paper axes
|
||||
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
|
||||
|
||||
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
|
||||
///< paper axis
|
||||
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
|
||||
///< axis
|
||||
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
|
||||
///< paper axis
|
||||
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
|
||||
///< axis
|
||||
double scale; ///< Scale all positions to reduce / increase density
|
||||
bool mergeSimilarFactors; ///< Merge multiple factors that have the same
|
||||
///< connectivity
|
||||
|
||||
/// (optional for each factor) Manually specify factor "dot" positions:
|
||||
std::map<size_t, Vector2> factorPositions;
|
||||
|
||||
/// Default constructor sets up robot coordinates. Paper horizontal is robot
|
||||
/// Y, paper vertical is robot X. Default figure size of 5x5 in.
|
||||
GraphvizFormatting()
|
||||
|
@ -55,8 +52,8 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
|
|||
// Find bounds
|
||||
Vector2 findBounds(const Values& values, const KeySet& keys) const;
|
||||
|
||||
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
|
||||
boost::optional<Vector2> operator()(const Value& value) const;
|
||||
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
|
||||
boost::optional<Vector2> extractPosition(const Value& value) const;
|
||||
|
||||
/// Return affinely transformed variable position if it exists.
|
||||
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,
|
||||
|
|
|
@ -102,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
|
|||
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||
const KeyFormatter& keyFormatter,
|
||||
const GraphvizFormatting& writer) const {
|
||||
writer.writePreamble(&os);
|
||||
writer.graphPreamble(&os);
|
||||
|
||||
// Find bounds (imperative)
|
||||
KeySet keys = this->keys();
|
||||
|
@ -111,7 +111,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
|||
// Create nodes for each variable in the graph
|
||||
for (Key key : keys) {
|
||||
auto position = writer.variablePos(values, min, key);
|
||||
writer.DrawVariable(key, keyFormatter, position, &os);
|
||||
writer.drawVariable(key, keyFormatter, position, &os);
|
||||
}
|
||||
os << "\n";
|
||||
|
||||
|
|
|
@ -43,12 +43,14 @@ namespace gtsam {
|
|||
class ExpressionFactor;
|
||||
|
||||
/**
|
||||
* A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors,
|
||||
* which derive from NonlinearFactor. The values structures are typically (in SAM) more general
|
||||
* than just vectors, e.g., Rot3 or Pose3, which are objects in non-linear manifolds.
|
||||
* Linearizing the non-linear factor graph creates a linear factor graph on the
|
||||
* tangent vector space at the linearization point. Because the tangent space is a true
|
||||
* vector space, the config type will be an VectorValues in that linearized factor graph.
|
||||
* A NonlinearFactorGraph is a graph of non-Gaussian, i.e. non-linear factors,
|
||||
* which derive from NonlinearFactor. The values structures are typically (in
|
||||
* SAM) more general than just vectors, e.g., Rot3 or Pose3, which are objects
|
||||
* in non-linear manifolds. Linearizing the non-linear factor graph creates a
|
||||
* linear factor graph on the tangent vector space at the linearization point.
|
||||
* Because the tangent space is a true vector space, the config type will be
|
||||
* an VectorValues in that linearized factor graph.
|
||||
* @addtogroup nonlinear
|
||||
*/
|
||||
class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> {
|
||||
|
||||
|
@ -58,6 +60,9 @@ namespace gtsam {
|
|||
typedef NonlinearFactorGraph This;
|
||||
typedef boost::shared_ptr<This> shared_ptr;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Default constructor */
|
||||
NonlinearFactorGraph() {}
|
||||
|
||||
|
@ -76,6 +81,10 @@ namespace gtsam {
|
|||
/// Destructor
|
||||
virtual ~NonlinearFactorGraph() {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** print */
|
||||
void print(
|
||||
const std::string& str = "NonlinearFactorGraph: ",
|
||||
|
@ -90,6 +99,10 @@ namespace gtsam {
|
|||
/** Test equality */
|
||||
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
|
||||
double error(const Values& values) const;
|
||||
|
||||
|
@ -206,6 +219,7 @@ namespace gtsam {
|
|||
emplace_shared<PriorFactor<T>>(key, prior, covariance);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Graph Display
|
||||
/// @{
|
||||
|
||||
|
@ -215,20 +229,19 @@ namespace gtsam {
|
|||
/// Output to graphviz format, stream version, with Values/extra options.
|
||||
void dot(std::ostream& os, const Values& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const GraphvizFormatting& graphvizFormatting =
|
||||
GraphvizFormatting()) const;
|
||||
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||
|
||||
/// Output to graphviz format string, with Values/extra options.
|
||||
std::string dot(const Values& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const GraphvizFormatting& graphvizFormatting =
|
||||
GraphvizFormatting()) const;
|
||||
std::string dot(
|
||||
const Values& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||
|
||||
/// output to file with graphviz format, with Values/extra options.
|
||||
void saveGraph(const std::string& filename, const Values& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const GraphvizFormatting& graphvizFormatting =
|
||||
GraphvizFormatting()) const;
|
||||
void saveGraph(
|
||||
const std::string& filename, const Values& values,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
@ -251,6 +264,8 @@ namespace gtsam {
|
|||
public:
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated
|
||||
/// @{
|
||||
/** @deprecated */
|
||||
boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor(
|
||||
const Values& values, boost::none_t, const Dampen& dampen = nullptr) const
|
||||
|
@ -275,6 +290,7 @@ namespace gtsam {
|
|||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||
saveGraph(filename, values, keyFormatter, graphvizFormatting);
|
||||
}
|
||||
/// @}
|
||||
#endif
|
||||
|
||||
};
|
||||
|
|
|
@ -95,18 +95,17 @@ class NonlinearFactorGraph {
|
|||
gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const;
|
||||
gtsam::NonlinearFactorGraph clone() const;
|
||||
|
||||
// enabling serialization functionality
|
||||
void serialize() const;
|
||||
|
||||
string dot(
|
||||
const gtsam::Values& values,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const GraphvizFormatting& writer = GraphvizFormatting());
|
||||
void saveGraph(const string& s, const gtsam::Values& values,
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter,
|
||||
const GraphvizFormatting& writer =
|
||||
GraphvizFormatting()) const;
|
||||
const GraphvizFormatting& formatting = GraphvizFormatting());
|
||||
void saveGraph(
|
||||
const string& s, const gtsam::Values& values,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const GraphvizFormatting& formatting = GraphvizFormatting()) const;
|
||||
|
||||
// enabling serialization functionality
|
||||
void serialize() const;
|
||||
};
|
||||
|
||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||
|
|
|
@ -16,41 +16,16 @@
|
|||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
|
||||
#include <boost/range/adaptor/reversed.hpp>
|
||||
#include <fstream>
|
||||
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
template class FactorGraph<SymbolicConditional>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool SymbolicBayesNet::equals(const This& bn, double tol) const
|
||||
{
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void SymbolicBayesNet::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const
|
||||
{
|
||||
std::ofstream of(s.c_str());
|
||||
of << "digraph G{\n";
|
||||
|
||||
for (auto conditional: boost::adaptors::reverse(*this)) {
|
||||
SymbolicConditional::Frontals frontals = conditional->frontals();
|
||||
Key me = frontals.front();
|
||||
SymbolicConditional::Parents parents = conditional->parents();
|
||||
for(Key p: parents)
|
||||
of << p << "->" << me << std::endl;
|
||||
}
|
||||
|
||||
of << "}";
|
||||
of.close();
|
||||
}
|
||||
|
||||
// Instantiate base class
|
||||
template class FactorGraph<SymbolicConditional>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool SymbolicBayesNet::equals(const This& bn, double tol) const {
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -19,19 +19,19 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/base/types.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/** Symbolic Bayes Net
|
||||
* \nosubgrouping
|
||||
/**
|
||||
* A SymbolicBayesNet is a Bayes Net of purely symbolic conditionals.
|
||||
* @addtogroup symbolic
|
||||
*/
|
||||
class SymbolicBayesNet : public FactorGraph<SymbolicConditional> {
|
||||
|
||||
public:
|
||||
|
||||
typedef FactorGraph<SymbolicConditional> Base;
|
||||
class SymbolicBayesNet : public BayesNet<SymbolicConditional> {
|
||||
public:
|
||||
typedef BayesNet<SymbolicConditional> Base;
|
||||
typedef SymbolicBayesNet This;
|
||||
typedef SymbolicConditional ConditionalType;
|
||||
typedef boost::shared_ptr<This> shared_ptr;
|
||||
|
@ -44,16 +44,21 @@ namespace gtsam {
|
|||
SymbolicBayesNet() {}
|
||||
|
||||
/** Construct from iterator over conditionals */
|
||||
template<typename ITERATOR>
|
||||
SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
||||
template <typename ITERATOR>
|
||||
SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||
: Base(firstConditional, lastConditional) {}
|
||||
|
||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||
template<class CONTAINER>
|
||||
explicit SymbolicBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
||||
template <class CONTAINER>
|
||||
explicit SymbolicBayesNet(const CONTAINER& conditionals) {
|
||||
push_back(conditionals);
|
||||
}
|
||||
|
||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
||||
template<class DERIVEDCONDITIONAL>
|
||||
SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
||||
/** Implicit copy/downcast constructor to override explicit template
|
||||
* container constructor */
|
||||
template <class DERIVEDCONDITIONAL>
|
||||
explicit SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||
: Base(graph) {}
|
||||
|
||||
/// Destructor
|
||||
virtual ~SymbolicBayesNet() {}
|
||||
|
@ -75,13 +80,6 @@ namespace gtsam {
|
|||
|
||||
/// @}
|
||||
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
GTSAM_EXPORT void saveGraph(const std::string &s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
|
@ -77,6 +77,14 @@ virtual class SymbolicFactorGraph {
|
|||
const gtsam::KeyVector& key_vector,
|
||||
const gtsam::Ordering& marginalizedVariableOrdering);
|
||||
gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector);
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
};
|
||||
|
||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||
|
@ -98,6 +106,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor {
|
|||
bool equals(const gtsam::SymbolicConditional& other, double tol) const;
|
||||
|
||||
// Standard interface
|
||||
gtsam::Key firstFrontalKey() const;
|
||||
size_t nrFrontals() const;
|
||||
size_t nrParents() const;
|
||||
};
|
||||
|
@ -120,6 +129,14 @@ class SymbolicBayesNet {
|
|||
gtsam::SymbolicConditional* back() const;
|
||||
void push_back(gtsam::SymbolicConditional* conditional);
|
||||
void push_back(const gtsam::SymbolicBayesNet& bayesNet);
|
||||
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
void saveGraph(
|
||||
string s,
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||
};
|
||||
|
||||
#include <gtsam/symbolic/SymbolicBayesTree.h>
|
||||
|
|
|
@ -15,13 +15,16 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/base/VectorSpace.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
@ -30,7 +33,6 @@ static const Key _L_ = 0;
|
|||
static const Key _A_ = 1;
|
||||
static const Key _B_ = 2;
|
||||
static const Key _C_ = 3;
|
||||
static const Key _D_ = 4;
|
||||
|
||||
static SymbolicConditional::shared_ptr
|
||||
B(new SymbolicConditional(_B_)),
|
||||
|
@ -78,14 +80,41 @@ TEST( SymbolicBayesNet, combine )
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(SymbolicBayesNet, saveGraph) {
|
||||
TEST(SymbolicBayesNet, Dot) {
|
||||
using symbol_shorthand::A;
|
||||
using symbol_shorthand::X;
|
||||
SymbolicBayesNet bn;
|
||||
bn += SymbolicConditional(_A_, _B_);
|
||||
KeyVector keys {_B_, _C_, _D_};
|
||||
bn += SymbolicConditional::FromKeys(keys,2);
|
||||
bn += SymbolicConditional(_D_);
|
||||
bn += SymbolicConditional(X(3), X(2), A(2));
|
||||
bn += SymbolicConditional(X(2), X(1), A(1));
|
||||
bn += SymbolicConditional(X(1));
|
||||
|
||||
bn.saveGraph("SymbolicBayesNet.dot");
|
||||
DotWriter writer;
|
||||
writer.positionHints.emplace('a', 2);
|
||||
writer.positionHints.emplace('x', 1);
|
||||
writer.boxes.emplace(A(1));
|
||||
writer.boxes.emplace(A(2));
|
||||
|
||||
auto position = writer.variablePos(A(1));
|
||||
CHECK(position);
|
||||
EXPECT(assert_equal(Vector2(1, 2), *position, 1e-5));
|
||||
|
||||
string actual = bn.dot(DefaultKeyFormatter, writer);
|
||||
bn.saveGraph("bn.dot", DefaultKeyFormatter, writer);
|
||||
EXPECT(actual ==
|
||||
"digraph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" vara1[label=\"a1\", pos=\"1,2!\", shape=box];\n"
|
||||
" vara2[label=\"a2\", pos=\"2,2!\", shape=box];\n"
|
||||
" varx1[label=\"x1\", pos=\"1,1!\"];\n"
|
||||
" varx2[label=\"x2\", pos=\"2,1!\"];\n"
|
||||
" varx3[label=\"x3\", pos=\"3,1!\"];\n"
|
||||
"\n"
|
||||
" varx1->varx2\n"
|
||||
" vara1->varx2\n"
|
||||
" varx2->varx3\n"
|
||||
" vara2->varx3\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -78,7 +78,7 @@ class TestGraphvizFormatting(GtsamTestCase):
|
|||
graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X
|
||||
graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y
|
||||
self.assertEqual(self.graph.dot(self.values,
|
||||
writer=graphviz_formatting),
|
||||
formatting=graphviz_formatting),
|
||||
textwrap.dedent(expected_result))
|
||||
|
||||
def test_factor_points(self):
|
||||
|
@ -100,7 +100,7 @@ class TestGraphvizFormatting(GtsamTestCase):
|
|||
graphviz_formatting.plotFactorPoints = False
|
||||
|
||||
self.assertEqual(self.graph.dot(self.values,
|
||||
writer=graphviz_formatting),
|
||||
formatting=graphviz_formatting),
|
||||
textwrap.dedent(expected_result))
|
||||
|
||||
def test_width_height(self):
|
||||
|
@ -127,7 +127,7 @@ class TestGraphvizFormatting(GtsamTestCase):
|
|||
graphviz_formatting.figureHeightInches = 10
|
||||
|
||||
self.assertEqual(self.graph.dot(self.values,
|
||||
writer=graphviz_formatting),
|
||||
formatting=graphviz_formatting),
|
||||
textwrap.dedent(expected_result))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue