Merge branch 'develop' into wrap-karcher-mean-rot3
commit
01f3fe50e4
|
@ -31,11 +31,12 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** A Bayes net made from discrete conditional distributions. */
|
/**
|
||||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
|
* A Bayes net made from discrete conditional distributions.
|
||||||
{
|
* @addtogroup discrete
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef BayesNet<DiscreteConditional> Base;
|
typedef BayesNet<DiscreteConditional> Base;
|
||||||
typedef DiscreteBayesNet This;
|
typedef DiscreteBayesNet This;
|
||||||
typedef DiscreteConditional ConditionalType;
|
typedef DiscreteConditional ConditionalType;
|
||||||
|
@ -50,15 +51,19 @@ namespace gtsam {
|
||||||
|
|
||||||
/** Construct from iterator over conditionals */
|
/** Construct from iterator over conditionals */
|
||||||
template <typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||||
|
: Base(firstConditional, lastConditional) {}
|
||||||
|
|
||||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||||
template <class CONTAINER>
|
template <class CONTAINER>
|
||||||
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
explicit DiscreteBayesNet(const CONTAINER& conditionals)
|
||||||
|
: Base(conditionals) {}
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
/** Implicit copy/downcast constructor to override explicit template
|
||||||
|
* container constructor */
|
||||||
template <class DERIVEDCONDITIONAL>
|
template <class DERIVEDCONDITIONAL>
|
||||||
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||||
|
: Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~DiscreteBayesNet() {}
|
virtual ~DiscreteBayesNet() {}
|
||||||
|
|
|
@ -102,6 +102,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
|
||||||
|
gtsam::Key firstFrontalKey() const;
|
||||||
size_t nrFrontals() const;
|
size_t nrFrontals() const;
|
||||||
size_t nrParents() const;
|
size_t nrParents() const;
|
||||||
void printSignature(
|
void printSignature(
|
||||||
|
@ -156,13 +157,17 @@ class DiscreteBayesNet {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
|
||||||
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues sample() const;
|
gtsam::DiscreteValues sample() const;
|
||||||
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
@ -228,19 +233,6 @@ class DiscreteLookupDAG {
|
||||||
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
|
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/inference/DotWriter.h>
|
|
||||||
class DotWriter {
|
|
||||||
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
|
|
||||||
bool plotFactorPoints = true, bool connectKeysToFactor = true,
|
|
||||||
bool binaryEdges = true);
|
|
||||||
|
|
||||||
double figureWidthInches;
|
|
||||||
double figureHeightInches;
|
|
||||||
bool plotFactorPoints;
|
|
||||||
bool connectKeysToFactor;
|
|
||||||
bool binaryEdges;
|
|
||||||
};
|
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
class DiscreteFactorGraph {
|
class DiscreteFactorGraph {
|
||||||
DiscreteFactorGraph();
|
DiscreteFactorGraph();
|
||||||
|
@ -265,14 +257,6 @@ class DiscreteFactorGraph {
|
||||||
void print(string s = "") const;
|
void print(string s = "") const;
|
||||||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||||
|
|
||||||
string dot(
|
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
|
||||||
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
|
|
||||||
void saveGraph(
|
|
||||||
string s,
|
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
|
||||||
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
|
|
||||||
|
|
||||||
gtsam::DecisionTreeFactor product() const;
|
gtsam::DecisionTreeFactor product() const;
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
@ -294,6 +278,14 @@ class DiscreteFactorGraph {
|
||||||
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
|
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
|
||||||
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
|
|
@ -150,12 +150,21 @@ TEST(DiscreteBayesNet, Dot) {
|
||||||
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||||
|
|
||||||
string actual = fragment.dot();
|
string actual = fragment.dot();
|
||||||
|
cout << actual << endl;
|
||||||
EXPECT(actual ==
|
EXPECT(actual ==
|
||||||
"digraph G{\n"
|
"digraph {\n"
|
||||||
"0->3\n"
|
" size=\"5,5\";\n"
|
||||||
"4->6\n"
|
"\n"
|
||||||
"3->5\n"
|
" var0[label=\"0\"];\n"
|
||||||
"6->5\n"
|
" var3[label=\"3\"];\n"
|
||||||
|
" var4[label=\"4\"];\n"
|
||||||
|
" var5[label=\"5\"];\n"
|
||||||
|
" var6[label=\"6\"];\n"
|
||||||
|
"\n"
|
||||||
|
" var3->var5\n"
|
||||||
|
" var6->var5\n"
|
||||||
|
" var4->var6\n"
|
||||||
|
" var0->var3\n"
|
||||||
"}");
|
"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,33 +18,43 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
|
||||||
#include <boost/range/adaptor/reversed.hpp>
|
#include <boost/range/adaptor/reversed.hpp>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <class CONDITIONAL>
|
template <class CONDITIONAL>
|
||||||
void BayesNet<CONDITIONAL>::print(
|
void BayesNet<CONDITIONAL>::print(const std::string& s,
|
||||||
const std::string& s, const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
Base::print(s, formatter);
|
Base::print(s, formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <class CONDITIONAL>
|
template <class CONDITIONAL>
|
||||||
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter,
|
||||||
os << "digraph G{\n";
|
const DotWriter& writer) const {
|
||||||
|
writer.digraphPreamble(&os);
|
||||||
|
|
||||||
for (auto conditional : *this) {
|
// Create nodes for each variable in the graph
|
||||||
|
for (Key key : this->keys()) {
|
||||||
|
auto position = writer.variablePos(key);
|
||||||
|
writer.drawVariable(key, keyFormatter, position, &os);
|
||||||
|
}
|
||||||
|
os << "\n";
|
||||||
|
|
||||||
|
// Reverse order as typically Bayes nets stored in reverse topological sort.
|
||||||
|
for (auto conditional : boost::adaptors::reverse(*this)) {
|
||||||
auto frontals = conditional->frontals();
|
auto frontals = conditional->frontals();
|
||||||
const Key me = frontals.front();
|
const Key me = frontals.front();
|
||||||
auto parents = conditional->parents();
|
auto parents = conditional->parents();
|
||||||
for (const Key& p : parents)
|
for (const Key& p : parents)
|
||||||
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
|
os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
os << "}";
|
os << "}";
|
||||||
|
@ -53,18 +63,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <class CONDITIONAL>
|
template <class CONDITIONAL>
|
||||||
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
|
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
|
||||||
|
const DotWriter& writer) const {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
dot(ss, keyFormatter);
|
dot(ss, keyFormatter, writer);
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <class CONDITIONAL>
|
template <class CONDITIONAL>
|
||||||
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
|
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DotWriter& writer) const {
|
||||||
std::ofstream of(filename.c_str());
|
std::ofstream of(filename.c_str());
|
||||||
dot(of, keyFormatter);
|
dot(of, keyFormatter, writer);
|
||||||
of.close();
|
of.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,38 +18,37 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <boost/shared_ptr.hpp>
|
|
||||||
|
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
|
||||||
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A BayesNet is a tree of conditionals, stored in elimination order.
|
* A BayesNet is a tree of conditionals, stored in elimination order.
|
||||||
*
|
* @addtogroup inference
|
||||||
* todo: how to handle Bayes nets with an optimize function? Currently using global functions.
|
|
||||||
* \nosubgrouping
|
|
||||||
*/
|
*/
|
||||||
template <class CONDITIONAL>
|
template <class CONDITIONAL>
|
||||||
class BayesNet : public FactorGraph<CONDITIONAL> {
|
class BayesNet : public FactorGraph<CONDITIONAL> {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
typedef FactorGraph<CONDITIONAL> Base;
|
typedef FactorGraph<CONDITIONAL> Base;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional
|
typedef typename boost::shared_ptr<CONDITIONAL>
|
||||||
|
sharedConditional; ///< A shared pointer to a conditional
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** Default constructor as an empty BayesNet */
|
/** Default constructor as an empty BayesNet */
|
||||||
BayesNet() {};
|
BayesNet() {}
|
||||||
|
|
||||||
/** Construct from iterator over conditionals */
|
/** Construct from iterator over conditionals */
|
||||||
template <typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||||
|
: Base(firstConditional, lastConditional) {}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
@ -68,19 +67,22 @@ namespace gtsam {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Output to graphviz format, stream version.
|
/// Output to graphviz format, stream version.
|
||||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
void dot(std::ostream& os,
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DotWriter& writer = DotWriter()) const;
|
||||||
|
|
||||||
/// Output to graphviz format string.
|
/// Output to graphviz format string.
|
||||||
std::string dot(
|
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
const DotWriter& writer = DotWriter()) const;
|
||||||
|
|
||||||
/// output to file with graphviz format.
|
/// output to file with graphviz format.
|
||||||
void saveGraph(const std::string& filename,
|
void saveGraph(const std::string& filename,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
const DotWriter& writer = DotWriter()) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
} // namespace gtsam
|
||||||
|
|
||||||
#include <gtsam/inference/BayesNet-inst.h>
|
#include <gtsam/inference/BayesNet-inst.h>
|
||||||
|
|
|
@ -16,30 +16,41 @@
|
||||||
* @date December, 2021
|
* @date December, 2021
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Vector.h>
|
|
||||||
#include <gtsam/inference/DotWriter.h>
|
#include <gtsam/inference/DotWriter.h>
|
||||||
|
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
#include <ostream>
|
#include <ostream>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
void DotWriter::writePreamble(ostream* os) const {
|
void DotWriter::graphPreamble(ostream* os) const {
|
||||||
*os << "graph {\n";
|
*os << "graph {\n";
|
||||||
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
||||||
<< "\";\n\n";
|
<< "\";\n\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
void DotWriter::digraphPreamble(ostream* os) const {
|
||||||
|
*os << "digraph {\n";
|
||||||
|
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
||||||
|
<< "\";\n\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void DotWriter::drawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
ostream* os) {
|
ostream* os) const {
|
||||||
// Label the node with the label from the KeyFormatter
|
// Label the node with the label from the KeyFormatter
|
||||||
*os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key)
|
*os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key)
|
||||||
<< "\"";
|
<< "\"";
|
||||||
if (position) {
|
if (position) {
|
||||||
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
|
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
|
||||||
}
|
}
|
||||||
|
if (boxes.count(key)) {
|
||||||
|
*os << ", shape=box";
|
||||||
|
}
|
||||||
*os << "];\n";
|
*os << "];\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,8 +64,7 @@ void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ConnectVariables(Key key1, Key key2,
|
static void ConnectVariables(Key key1, Key key2,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter, ostream* os) {
|
||||||
ostream* os) {
|
|
||||||
*os << " var" << keyFormatter(key1) << "--"
|
*os << " var" << keyFormatter(key1) << "--"
|
||||||
<< "var" << keyFormatter(key2) << ";\n";
|
<< "var" << keyFormatter(key2) << ";\n";
|
||||||
}
|
}
|
||||||
|
@ -65,6 +75,24 @@ static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter,
|
||||||
<< "factor" << i << ";\n";
|
<< "factor" << i << ";\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return variable position or none
|
||||||
|
boost::optional<Vector2> DotWriter::variablePos(Key key) const {
|
||||||
|
boost::optional<Vector2> result = boost::none;
|
||||||
|
|
||||||
|
// Check position hint
|
||||||
|
Symbol symbol(key);
|
||||||
|
auto hint = positionHints.find(symbol.chr());
|
||||||
|
if (hint != positionHints.end())
|
||||||
|
result.reset(Vector2(symbol.index(), hint->second));
|
||||||
|
|
||||||
|
// Override with explicit position, if given.
|
||||||
|
auto pos = variablePositions.find(key);
|
||||||
|
if (pos != variablePositions.end())
|
||||||
|
result.reset(pos->second);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
|
@ -74,6 +102,9 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
||||||
ConnectVariables(keys[0], keys[1], keyFormatter, os);
|
ConnectVariables(keys[0], keys[1], keyFormatter, os);
|
||||||
} else {
|
} else {
|
||||||
// Create dot for the factor.
|
// Create dot for the factor.
|
||||||
|
if (!position && factorPositions.count(i))
|
||||||
|
DrawFactor(i, factorPositions.at(i), os);
|
||||||
|
else
|
||||||
DrawFactor(i, position, os);
|
DrawFactor(i, position, os);
|
||||||
|
|
||||||
// Make factor-variable connections
|
// Make factor-variable connections
|
||||||
|
|
|
@ -23,10 +23,15 @@
|
||||||
#include <gtsam/inference/Key.h>
|
#include <gtsam/inference/Key.h>
|
||||||
|
|
||||||
#include <iosfwd>
|
#include <iosfwd>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Graphviz formatter.
|
/**
|
||||||
|
* @brief DotWriter is a helper class for writing graphviz .dot files.
|
||||||
|
* @addtogroup inference
|
||||||
|
*/
|
||||||
struct GTSAM_EXPORT DotWriter {
|
struct GTSAM_EXPORT DotWriter {
|
||||||
double figureWidthInches; ///< The figure width on paper in inches
|
double figureWidthInches; ///< The figure width on paper in inches
|
||||||
double figureHeightInches; ///< The figure height on paper in inches
|
double figureHeightInches; ///< The figure height on paper in inches
|
||||||
|
@ -35,6 +40,28 @@ struct GTSAM_EXPORT DotWriter {
|
||||||
///< the dot of the factor
|
///< the dot of the factor
|
||||||
bool binaryEdges; ///< just use non-dotted edges for binary factors
|
bool binaryEdges; ///< just use non-dotted edges for binary factors
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Variable positions can be optionally specified and will be included in the
|
||||||
|
* dot file with a "!' sign, so "neato" can use it to render them.
|
||||||
|
*/
|
||||||
|
std::map<Key, Vector2> variablePositions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The position hints allow one to use symbol character and index to specify
|
||||||
|
* position. Unless variable positions are specified, if a hint is present for
|
||||||
|
* a given symbol, it will be used to calculate the positions as (index,hint).
|
||||||
|
*/
|
||||||
|
std::map<char, double> positionHints;
|
||||||
|
|
||||||
|
/** A set of keys that will be displayed as a box */
|
||||||
|
std::set<Key> boxes;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Factor positions can be optionally specified and will be included in the
|
||||||
|
* dot file with a "!' sign, so "neato" can use it to render them.
|
||||||
|
*/
|
||||||
|
std::map<size_t, Vector2> factorPositions;
|
||||||
|
|
||||||
explicit DotWriter(double figureWidthInches = 5,
|
explicit DotWriter(double figureWidthInches = 5,
|
||||||
double figureHeightInches = 5,
|
double figureHeightInches = 5,
|
||||||
bool plotFactorPoints = true,
|
bool plotFactorPoints = true,
|
||||||
|
@ -45,18 +72,24 @@ struct GTSAM_EXPORT DotWriter {
|
||||||
connectKeysToFactor(connectKeysToFactor),
|
connectKeysToFactor(connectKeysToFactor),
|
||||||
binaryEdges(binaryEdges) {}
|
binaryEdges(binaryEdges) {}
|
||||||
|
|
||||||
/// Write out preamble, including size.
|
/// Write out preamble for graph, including size.
|
||||||
void writePreamble(std::ostream* os) const;
|
void graphPreamble(std::ostream* os) const;
|
||||||
|
|
||||||
|
/// Write out preamble for digraph, including size.
|
||||||
|
void digraphPreamble(std::ostream* os) const;
|
||||||
|
|
||||||
/// Create a variable dot fragment.
|
/// Create a variable dot fragment.
|
||||||
static void DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
void drawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
std::ostream* os);
|
std::ostream* os) const;
|
||||||
|
|
||||||
/// Create factor dot.
|
/// Create factor dot.
|
||||||
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||||
std::ostream* os);
|
std::ostream* os);
|
||||||
|
|
||||||
|
/// Return variable position or none
|
||||||
|
boost::optional<Vector2> variablePos(Key key) const;
|
||||||
|
|
||||||
/// Draw a single factor, specified by its index i and its variable keys.
|
/// Draw a single factor, specified by its index i and its variable keys.
|
||||||
void processFactor(size_t i, const KeyVector& keys,
|
void processFactor(size_t i, const KeyVector& keys,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
|
|
|
@ -131,11 +131,12 @@ template <class FACTOR>
|
||||||
void FactorGraph<FACTOR>::dot(std::ostream& os,
|
void FactorGraph<FACTOR>::dot(std::ostream& os,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
const DotWriter& writer) const {
|
const DotWriter& writer) const {
|
||||||
writer.writePreamble(&os);
|
writer.graphPreamble(&os);
|
||||||
|
|
||||||
// Create nodes for each variable in the graph
|
// Create nodes for each variable in the graph
|
||||||
for (Key key : keys()) {
|
for (Key key : keys()) {
|
||||||
writer.DrawVariable(key, keyFormatter, boost::none, &os);
|
auto position = writer.variablePos(key);
|
||||||
|
writer.drawVariable(key, keyFormatter, position, &os);
|
||||||
}
|
}
|
||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,168 @@
|
||||||
|
//*************************************************************************
|
||||||
|
// inference
|
||||||
|
//*************************************************************************
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
|
||||||
|
// Default keyformatter
|
||||||
|
void PrintKeyList(
|
||||||
|
const gtsam::KeyList& keys, const string& s = "",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
void PrintKeyVector(
|
||||||
|
const gtsam::KeyVector& keys, const string& s = "",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
void PrintKeySet(
|
||||||
|
const gtsam::KeySet& keys, const string& s = "",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
class Symbol {
|
||||||
|
Symbol();
|
||||||
|
Symbol(char c, uint64_t j);
|
||||||
|
Symbol(size_t key);
|
||||||
|
|
||||||
|
size_t key() const;
|
||||||
|
void print(const string& s = "") const;
|
||||||
|
bool equals(const gtsam::Symbol& expected, double tol) const;
|
||||||
|
|
||||||
|
char chr() const;
|
||||||
|
uint64_t index() const;
|
||||||
|
string string() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t symbol(char chr, size_t index);
|
||||||
|
char symbolChr(size_t key);
|
||||||
|
size_t symbolIndex(size_t key);
|
||||||
|
|
||||||
|
namespace symbol_shorthand {
|
||||||
|
size_t A(size_t j);
|
||||||
|
size_t B(size_t j);
|
||||||
|
size_t C(size_t j);
|
||||||
|
size_t D(size_t j);
|
||||||
|
size_t E(size_t j);
|
||||||
|
size_t F(size_t j);
|
||||||
|
size_t G(size_t j);
|
||||||
|
size_t H(size_t j);
|
||||||
|
size_t I(size_t j);
|
||||||
|
size_t J(size_t j);
|
||||||
|
size_t K(size_t j);
|
||||||
|
size_t L(size_t j);
|
||||||
|
size_t M(size_t j);
|
||||||
|
size_t N(size_t j);
|
||||||
|
size_t O(size_t j);
|
||||||
|
size_t P(size_t j);
|
||||||
|
size_t Q(size_t j);
|
||||||
|
size_t R(size_t j);
|
||||||
|
size_t S(size_t j);
|
||||||
|
size_t T(size_t j);
|
||||||
|
size_t U(size_t j);
|
||||||
|
size_t V(size_t j);
|
||||||
|
size_t W(size_t j);
|
||||||
|
size_t X(size_t j);
|
||||||
|
size_t Y(size_t j);
|
||||||
|
size_t Z(size_t j);
|
||||||
|
} // namespace symbol_shorthand
|
||||||
|
|
||||||
|
#include <gtsam/inference/LabeledSymbol.h>
|
||||||
|
class LabeledSymbol {
|
||||||
|
LabeledSymbol(size_t full_key);
|
||||||
|
LabeledSymbol(const gtsam::LabeledSymbol& key);
|
||||||
|
LabeledSymbol(unsigned char valType, unsigned char label, size_t j);
|
||||||
|
|
||||||
|
size_t key() const;
|
||||||
|
unsigned char label() const;
|
||||||
|
unsigned char chr() const;
|
||||||
|
size_t index() const;
|
||||||
|
|
||||||
|
gtsam::LabeledSymbol upper() const;
|
||||||
|
gtsam::LabeledSymbol lower() const;
|
||||||
|
gtsam::LabeledSymbol newChr(unsigned char c) const;
|
||||||
|
gtsam::LabeledSymbol newLabel(unsigned char label) const;
|
||||||
|
|
||||||
|
void print(string s = "") const;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t mrsymbol(unsigned char c, unsigned char label, size_t j);
|
||||||
|
unsigned char mrsymbolChr(size_t key);
|
||||||
|
unsigned char mrsymbolLabel(size_t key);
|
||||||
|
size_t mrsymbolIndex(size_t key);
|
||||||
|
|
||||||
|
#include <gtsam/inference/Ordering.h>
|
||||||
|
class Ordering {
|
||||||
|
/// Type of ordering to use
|
||||||
|
enum OrderingType { COLAMD, METIS, NATURAL, CUSTOM };
|
||||||
|
|
||||||
|
// Standard Constructors and Named Constructors
|
||||||
|
Ordering();
|
||||||
|
Ordering(const gtsam::Ordering& other);
|
||||||
|
|
||||||
|
template <FACTOR_GRAPH = {gtsam::NonlinearFactorGraph,
|
||||||
|
gtsam::GaussianFactorGraph}>
|
||||||
|
static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph);
|
||||||
|
|
||||||
|
// Testable
|
||||||
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::Ordering& ord, double tol) const;
|
||||||
|
|
||||||
|
// Standard interface
|
||||||
|
size_t size() const;
|
||||||
|
size_t at(size_t key) const;
|
||||||
|
void push_back(size_t key);
|
||||||
|
|
||||||
|
// enabling serialization functionality
|
||||||
|
void serialize() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/inference/DotWriter.h>
|
||||||
|
class DotWriter {
|
||||||
|
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
|
||||||
|
bool plotFactorPoints = true, bool connectKeysToFactor = true,
|
||||||
|
bool binaryEdges = true);
|
||||||
|
|
||||||
|
double figureWidthInches;
|
||||||
|
double figureHeightInches;
|
||||||
|
bool plotFactorPoints;
|
||||||
|
bool connectKeysToFactor;
|
||||||
|
bool binaryEdges;
|
||||||
|
|
||||||
|
std::map<gtsam::Key, gtsam::Vector2> variablePositions;
|
||||||
|
std::map<char, double> positionHints;
|
||||||
|
std::set<Key> boxes;
|
||||||
|
std::map<size_t, gtsam::Vector2> factorPositions;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/inference/VariableIndex.h>
|
||||||
|
|
||||||
|
// Headers for overloaded methods below, break hierarchy :-/
|
||||||
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||||
|
#include <gtsam/symbolic/SymbolicFactorGraph.h>
|
||||||
|
|
||||||
|
class VariableIndex {
|
||||||
|
// Standard Constructors and Named Constructors
|
||||||
|
VariableIndex();
|
||||||
|
// TODO: Templetize constructor when wrap supports it
|
||||||
|
// template<T = {gtsam::FactorGraph}>
|
||||||
|
// VariableIndex(const T& factorGraph, size_t nVariables);
|
||||||
|
// VariableIndex(const T& factorGraph);
|
||||||
|
VariableIndex(const gtsam::SymbolicFactorGraph& sfg);
|
||||||
|
VariableIndex(const gtsam::GaussianFactorGraph& gfg);
|
||||||
|
VariableIndex(const gtsam::NonlinearFactorGraph& fg);
|
||||||
|
VariableIndex(const gtsam::VariableIndex& other);
|
||||||
|
|
||||||
|
// Testable
|
||||||
|
bool equals(const gtsam::VariableIndex& other, double tol) const;
|
||||||
|
void print(string s = "VariableIndex: ",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
||||||
|
// Standard interface
|
||||||
|
size_t size() const;
|
||||||
|
size_t nFactors() const;
|
||||||
|
size_t nEntries() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
|
@ -205,23 +205,5 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void GaussianBayesNet::saveGraph(const std::string& s,
|
|
||||||
const KeyFormatter& keyFormatter) const {
|
|
||||||
std::ofstream of(s.c_str());
|
|
||||||
of << "digraph G{\n";
|
|
||||||
|
|
||||||
for (auto conditional : boost::adaptors::reverse(*this)) {
|
|
||||||
typename GaussianConditional::Frontals frontals = conditional->frontals();
|
|
||||||
Key me = frontals.front();
|
|
||||||
typename GaussianConditional::Parents parents = conditional->parents();
|
|
||||||
for (Key p : parents)
|
|
||||||
of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
of << "}";
|
|
||||||
of.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -21,17 +21,22 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
#include <gtsam/global_includes.h>
|
#include <gtsam/global_includes.h>
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** A Bayes net made from linear-Gaussian densities */
|
/**
|
||||||
class GTSAM_EXPORT GaussianBayesNet: public FactorGraph<GaussianConditional>
|
* GaussianBayesNet is a Bayes net made from linear-Gaussian conditionals.
|
||||||
|
* @addtogroup linear
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT GaussianBayesNet: public BayesNet<GaussianConditional>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef FactorGraph<GaussianConditional> Base;
|
typedef BayesNet<GaussianConditional> Base;
|
||||||
typedef GaussianBayesNet This;
|
typedef GaussianBayesNet This;
|
||||||
typedef GaussianConditional ConditionalType;
|
typedef GaussianConditional ConditionalType;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
@ -45,15 +50,20 @@ namespace gtsam {
|
||||||
|
|
||||||
/** Construct from iterator over conditionals */
|
/** Construct from iterator over conditionals */
|
||||||
template <typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||||
|
: Base(firstConditional, lastConditional) {}
|
||||||
|
|
||||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||||
template <class CONTAINER>
|
template <class CONTAINER>
|
||||||
explicit GaussianBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
explicit GaussianBayesNet(const CONTAINER& conditionals) {
|
||||||
|
push_back(conditionals);
|
||||||
|
}
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
/** Implicit copy/downcast constructor to override explicit template
|
||||||
|
* container constructor */
|
||||||
template <class DERIVEDCONDITIONAL>
|
template <class DERIVEDCONDITIONAL>
|
||||||
GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
explicit GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||||
|
: Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~GaussianBayesNet() {}
|
virtual ~GaussianBayesNet() {}
|
||||||
|
@ -66,6 +76,13 @@ namespace gtsam {
|
||||||
/** Check equality */
|
/** Check equality */
|
||||||
bool equals(const This& bn, double tol = 1e-9) const;
|
bool equals(const This& bn, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/// print graph
|
||||||
|
void print(
|
||||||
|
const std::string& s = "",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
||||||
|
Base::print(s, formatter);
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
|
@ -180,23 +197,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
VectorValues backSubstituteTranspose(const VectorValues& gx) const;
|
VectorValues backSubstituteTranspose(const VectorValues& gx) const;
|
||||||
|
|
||||||
/// print graph
|
|
||||||
void print(
|
|
||||||
const std::string& s = "",
|
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
|
||||||
Base::print(s, formatter);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Save the GaussianBayesNet as an image. Requires `dot` to be
|
|
||||||
* installed.
|
|
||||||
*
|
|
||||||
* @param s The name of the figure.
|
|
||||||
* @param keyFormatter Formatter to use for styling keys in the graph.
|
|
||||||
*/
|
|
||||||
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter =
|
|
||||||
DefaultKeyFormatter) const;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -437,6 +437,14 @@ class GaussianFactorGraph {
|
||||||
pair<Matrix,Vector> hessian() const;
|
pair<Matrix,Vector> hessian() const;
|
||||||
pair<Matrix,Vector> hessian(const gtsam::Ordering& ordering) const;
|
pair<Matrix,Vector> hessian(const gtsam::Ordering& ordering) const;
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
|
||||||
// enabling serialization functionality
|
// enabling serialization functionality
|
||||||
void serialize() const;
|
void serialize() const;
|
||||||
};
|
};
|
||||||
|
@ -444,11 +452,13 @@ class GaussianFactorGraph {
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
virtual class GaussianConditional : gtsam::JacobianFactor {
|
virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||||
// Constructors
|
// Constructors
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas);
|
GaussianConditional(size_t key, Vector d, Matrix R,
|
||||||
|
const gtsam::noiseModel::Diagonal* sigmas);
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
||||||
const gtsam::noiseModel::Diagonal* sigmas);
|
const gtsam::noiseModel::Diagonal* sigmas);
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
|
||||||
size_t name2, Matrix T, const gtsam::noiseModel::Diagonal* sigmas);
|
size_t name2, Matrix T,
|
||||||
|
const gtsam::noiseModel::Diagonal* sigmas);
|
||||||
|
|
||||||
// Constructors with no noise model
|
// Constructors with no noise model
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R);
|
GaussianConditional(size_t key, Vector d, Matrix R);
|
||||||
|
@ -461,6 +471,7 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
||||||
|
gtsam::Key firstFrontalKey() const;
|
||||||
|
|
||||||
// Advanced Interface
|
// Advanced Interface
|
||||||
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
|
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
|
||||||
|
@ -524,6 +535,14 @@ virtual class GaussianBayesNet {
|
||||||
double logDeterminant() const;
|
double logDeterminant() const;
|
||||||
gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const;
|
gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const;
|
||||||
gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const;
|
gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const;
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianBayesTree.h>
|
#include <gtsam/linear/GaussianBayesTree.h>
|
||||||
|
|
|
@ -301,5 +301,31 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
TEST(GaussianBayesNet, Dot) {
|
||||||
|
GaussianBayesNet fragment;
|
||||||
|
DotWriter writer;
|
||||||
|
writer.variablePositions.emplace(_x_, Vector2(10, 20));
|
||||||
|
writer.variablePositions.emplace(_y_, Vector2(50, 20));
|
||||||
|
|
||||||
|
auto position = writer.variablePos(_x_);
|
||||||
|
CHECK(position);
|
||||||
|
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
|
||||||
|
|
||||||
|
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
|
||||||
|
EXPECT(actual ==
|
||||||
|
"digraph {\n"
|
||||||
|
" size=\"5,5\";\n"
|
||||||
|
"\n"
|
||||||
|
" var11[label=\"11\", pos=\"10,20!\"];\n"
|
||||||
|
" var22[label=\"22\", pos=\"50,20!\"];\n"
|
||||||
|
"\n"
|
||||||
|
" var22->var11\n"
|
||||||
|
"}");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
|
||||||
min.y() = std::numeric_limits<double>::infinity();
|
min.y() = std::numeric_limits<double>::infinity();
|
||||||
for (const Key& key : keys) {
|
for (const Key& key : keys) {
|
||||||
if (values.exists(key)) {
|
if (values.exists(key)) {
|
||||||
boost::optional<Vector2> xy = operator()(values.at(key));
|
boost::optional<Vector2> xy = extractPosition(values.at(key));
|
||||||
if (xy) {
|
if (xy) {
|
||||||
if (xy->x() < min.x()) min.x() = xy->x();
|
if (xy->x() < min.x()) min.x() = xy->x();
|
||||||
if (xy->y() < min.y()) min.y() = xy->y();
|
if (xy->y() < min.y()) min.y() = xy->y();
|
||||||
|
@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
|
||||||
return min;
|
return min;
|
||||||
}
|
}
|
||||||
|
|
||||||
boost::optional<Vector2> GraphvizFormatting::operator()(
|
boost::optional<Vector2> GraphvizFormatting::extractPosition(
|
||||||
const Value& value) const {
|
const Value& value) const {
|
||||||
Vector3 t;
|
Vector3 t;
|
||||||
if (const GenericValue<Pose2>* p =
|
if (const GenericValue<Pose2>* p =
|
||||||
|
@ -121,12 +121,11 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
|
||||||
return Vector2(x, y);
|
return Vector2(x, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return affinely transformed variable position if it exists.
|
|
||||||
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
||||||
const Vector2& min,
|
const Vector2& min,
|
||||||
Key key) const {
|
Key key) const {
|
||||||
if (!values.exists(key)) return boost::none;
|
if (!values.exists(key)) return DotWriter::variablePos(key);
|
||||||
boost::optional<Vector2> xy = operator()(values.at(key));
|
boost::optional<Vector2> xy = extractPosition(values.at(key));
|
||||||
if (xy) {
|
if (xy) {
|
||||||
xy->x() = scale * (xy->x() - min.x());
|
xy->x() = scale * (xy->x() - min.x());
|
||||||
xy->y() = scale * (xy->y() - min.y());
|
xy->y() = scale * (xy->y() - min.y());
|
||||||
|
@ -134,7 +133,6 @@ boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
||||||
return xy;
|
return xy;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return affinely transformed factor position if it exists.
|
|
||||||
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
|
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
|
||||||
size_t i) const {
|
size_t i) const {
|
||||||
if (factorPositions.size() == 0) return boost::none;
|
if (factorPositions.size() == 0) return boost::none;
|
||||||
|
|
|
@ -41,9 +41,6 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
|
||||||
bool mergeSimilarFactors; ///< Merge multiple factors that have the same
|
bool mergeSimilarFactors; ///< Merge multiple factors that have the same
|
||||||
///< connectivity
|
///< connectivity
|
||||||
|
|
||||||
/// (optional for each factor) Manually specify factor "dot" positions:
|
|
||||||
std::map<size_t, Vector2> factorPositions;
|
|
||||||
|
|
||||||
/// Default constructor sets up robot coordinates. Paper horizontal is robot
|
/// Default constructor sets up robot coordinates. Paper horizontal is robot
|
||||||
/// Y, paper vertical is robot X. Default figure size of 5x5 in.
|
/// Y, paper vertical is robot X. Default figure size of 5x5 in.
|
||||||
GraphvizFormatting()
|
GraphvizFormatting()
|
||||||
|
@ -56,7 +53,7 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
|
||||||
Vector2 findBounds(const Values& values, const KeySet& keys) const;
|
Vector2 findBounds(const Values& values, const KeySet& keys) const;
|
||||||
|
|
||||||
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
|
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
|
||||||
boost::optional<Vector2> operator()(const Value& value) const;
|
boost::optional<Vector2> extractPosition(const Value& value) const;
|
||||||
|
|
||||||
/// Return affinely transformed variable position if it exists.
|
/// Return affinely transformed variable position if it exists.
|
||||||
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,
|
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,
|
||||||
|
|
|
@ -102,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
|
||||||
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
const GraphvizFormatting& writer) const {
|
const GraphvizFormatting& writer) const {
|
||||||
writer.writePreamble(&os);
|
writer.graphPreamble(&os);
|
||||||
|
|
||||||
// Find bounds (imperative)
|
// Find bounds (imperative)
|
||||||
KeySet keys = this->keys();
|
KeySet keys = this->keys();
|
||||||
|
@ -111,7 +111,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||||
// Create nodes for each variable in the graph
|
// Create nodes for each variable in the graph
|
||||||
for (Key key : keys) {
|
for (Key key : keys) {
|
||||||
auto position = writer.variablePos(values, min, key);
|
auto position = writer.variablePos(values, min, key);
|
||||||
writer.DrawVariable(key, keyFormatter, position, &os);
|
writer.drawVariable(key, keyFormatter, position, &os);
|
||||||
}
|
}
|
||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
|
|
|
@ -43,12 +43,14 @@ namespace gtsam {
|
||||||
class ExpressionFactor;
|
class ExpressionFactor;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors,
|
* A NonlinearFactorGraph is a graph of non-Gaussian, i.e. non-linear factors,
|
||||||
* which derive from NonlinearFactor. The values structures are typically (in SAM) more general
|
* which derive from NonlinearFactor. The values structures are typically (in
|
||||||
* than just vectors, e.g., Rot3 or Pose3, which are objects in non-linear manifolds.
|
* SAM) more general than just vectors, e.g., Rot3 or Pose3, which are objects
|
||||||
* Linearizing the non-linear factor graph creates a linear factor graph on the
|
* in non-linear manifolds. Linearizing the non-linear factor graph creates a
|
||||||
* tangent vector space at the linearization point. Because the tangent space is a true
|
* linear factor graph on the tangent vector space at the linearization point.
|
||||||
* vector space, the config type will be an VectorValues in that linearized factor graph.
|
* Because the tangent space is a true vector space, the config type will be
|
||||||
|
* an VectorValues in that linearized factor graph.
|
||||||
|
* @addtogroup nonlinear
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> {
|
class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> {
|
||||||
|
|
||||||
|
@ -58,6 +60,9 @@ namespace gtsam {
|
||||||
typedef NonlinearFactorGraph This;
|
typedef NonlinearFactorGraph This;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** Default constructor */
|
/** Default constructor */
|
||||||
NonlinearFactorGraph() {}
|
NonlinearFactorGraph() {}
|
||||||
|
|
||||||
|
@ -76,6 +81,10 @@ namespace gtsam {
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~NonlinearFactorGraph() {}
|
virtual ~NonlinearFactorGraph() {}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** print */
|
/** print */
|
||||||
void print(
|
void print(
|
||||||
const std::string& str = "NonlinearFactorGraph: ",
|
const std::string& str = "NonlinearFactorGraph: ",
|
||||||
|
@ -90,6 +99,10 @@ namespace gtsam {
|
||||||
/** Test equality */
|
/** Test equality */
|
||||||
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
|
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
|
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
|
||||||
double error(const Values& values) const;
|
double error(const Values& values) const;
|
||||||
|
|
||||||
|
@ -206,6 +219,7 @@ namespace gtsam {
|
||||||
emplace_shared<PriorFactor<T>>(key, prior, covariance);
|
emplace_shared<PriorFactor<T>>(key, prior, covariance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
/// @name Graph Display
|
/// @name Graph Display
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
@ -215,20 +229,19 @@ namespace gtsam {
|
||||||
/// Output to graphviz format, stream version, with Values/extra options.
|
/// Output to graphviz format, stream version, with Values/extra options.
|
||||||
void dot(std::ostream& os, const Values& values,
|
void dot(std::ostream& os, const Values& values,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const GraphvizFormatting& graphvizFormatting =
|
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||||
GraphvizFormatting()) const;
|
|
||||||
|
|
||||||
/// Output to graphviz format string, with Values/extra options.
|
/// Output to graphviz format string, with Values/extra options.
|
||||||
std::string dot(const Values& values,
|
std::string dot(
|
||||||
|
const Values& values,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const GraphvizFormatting& graphvizFormatting =
|
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||||
GraphvizFormatting()) const;
|
|
||||||
|
|
||||||
/// output to file with graphviz format, with Values/extra options.
|
/// output to file with graphviz format, with Values/extra options.
|
||||||
void saveGraph(const std::string& filename, const Values& values,
|
void saveGraph(
|
||||||
|
const std::string& filename, const Values& values,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const GraphvizFormatting& graphvizFormatting =
|
const GraphvizFormatting& writer = GraphvizFormatting()) const;
|
||||||
GraphvizFormatting()) const;
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -251,6 +264,8 @@ namespace gtsam {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated
|
||||||
|
/// @{
|
||||||
/** @deprecated */
|
/** @deprecated */
|
||||||
boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor(
|
boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor(
|
||||||
const Values& values, boost::none_t, const Dampen& dampen = nullptr) const
|
const Values& values, boost::none_t, const Dampen& dampen = nullptr) const
|
||||||
|
@ -275,6 +290,7 @@ namespace gtsam {
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||||
saveGraph(filename, values, keyFormatter, graphvizFormatting);
|
saveGraph(filename, values, keyFormatter, graphvizFormatting);
|
||||||
}
|
}
|
||||||
|
/// @}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
|
@ -23,121 +23,9 @@ namespace gtsam {
|
||||||
#include <gtsam/geometry/SOn.h>
|
#include <gtsam/geometry/SOn.h>
|
||||||
#include <gtsam/geometry/StereoPoint2.h>
|
#include <gtsam/geometry/StereoPoint2.h>
|
||||||
#include <gtsam/geometry/Unit3.h>
|
#include <gtsam/geometry/Unit3.h>
|
||||||
#include <gtsam/inference/Symbol.h>
|
|
||||||
#include <gtsam/navigation/ImuBias.h>
|
#include <gtsam/navigation/ImuBias.h>
|
||||||
#include <gtsam/navigation/NavState.h>
|
#include <gtsam/navigation/NavState.h>
|
||||||
|
|
||||||
class Symbol {
|
|
||||||
Symbol();
|
|
||||||
Symbol(char c, uint64_t j);
|
|
||||||
Symbol(size_t key);
|
|
||||||
|
|
||||||
size_t key() const;
|
|
||||||
void print(const string& s = "") const;
|
|
||||||
bool equals(const gtsam::Symbol& expected, double tol) const;
|
|
||||||
|
|
||||||
char chr() const;
|
|
||||||
uint64_t index() const;
|
|
||||||
string string() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
size_t symbol(char chr, size_t index);
|
|
||||||
char symbolChr(size_t key);
|
|
||||||
size_t symbolIndex(size_t key);
|
|
||||||
|
|
||||||
namespace symbol_shorthand {
|
|
||||||
size_t A(size_t j);
|
|
||||||
size_t B(size_t j);
|
|
||||||
size_t C(size_t j);
|
|
||||||
size_t D(size_t j);
|
|
||||||
size_t E(size_t j);
|
|
||||||
size_t F(size_t j);
|
|
||||||
size_t G(size_t j);
|
|
||||||
size_t H(size_t j);
|
|
||||||
size_t I(size_t j);
|
|
||||||
size_t J(size_t j);
|
|
||||||
size_t K(size_t j);
|
|
||||||
size_t L(size_t j);
|
|
||||||
size_t M(size_t j);
|
|
||||||
size_t N(size_t j);
|
|
||||||
size_t O(size_t j);
|
|
||||||
size_t P(size_t j);
|
|
||||||
size_t Q(size_t j);
|
|
||||||
size_t R(size_t j);
|
|
||||||
size_t S(size_t j);
|
|
||||||
size_t T(size_t j);
|
|
||||||
size_t U(size_t j);
|
|
||||||
size_t V(size_t j);
|
|
||||||
size_t W(size_t j);
|
|
||||||
size_t X(size_t j);
|
|
||||||
size_t Y(size_t j);
|
|
||||||
size_t Z(size_t j);
|
|
||||||
} // namespace symbol_shorthand
|
|
||||||
|
|
||||||
// Default keyformatter
|
|
||||||
void PrintKeyList(
|
|
||||||
const gtsam::KeyList& keys, const string& s = "",
|
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
|
||||||
void PrintKeyVector(
|
|
||||||
const gtsam::KeyVector& keys, const string& s = "",
|
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
|
||||||
void PrintKeySet(
|
|
||||||
const gtsam::KeySet& keys, const string& s = "",
|
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
|
||||||
|
|
||||||
#include <gtsam/inference/LabeledSymbol.h>
|
|
||||||
class LabeledSymbol {
|
|
||||||
LabeledSymbol(size_t full_key);
|
|
||||||
LabeledSymbol(const gtsam::LabeledSymbol& key);
|
|
||||||
LabeledSymbol(unsigned char valType, unsigned char label, size_t j);
|
|
||||||
|
|
||||||
size_t key() const;
|
|
||||||
unsigned char label() const;
|
|
||||||
unsigned char chr() const;
|
|
||||||
size_t index() const;
|
|
||||||
|
|
||||||
gtsam::LabeledSymbol upper() const;
|
|
||||||
gtsam::LabeledSymbol lower() const;
|
|
||||||
gtsam::LabeledSymbol newChr(unsigned char c) const;
|
|
||||||
gtsam::LabeledSymbol newLabel(unsigned char label) const;
|
|
||||||
|
|
||||||
void print(string s = "") const;
|
|
||||||
};
|
|
||||||
|
|
||||||
size_t mrsymbol(unsigned char c, unsigned char label, size_t j);
|
|
||||||
unsigned char mrsymbolChr(size_t key);
|
|
||||||
unsigned char mrsymbolLabel(size_t key);
|
|
||||||
size_t mrsymbolIndex(size_t key);
|
|
||||||
|
|
||||||
#include <gtsam/inference/Ordering.h>
|
|
||||||
class Ordering {
|
|
||||||
/// Type of ordering to use
|
|
||||||
enum OrderingType {
|
|
||||||
COLAMD, METIS, NATURAL, CUSTOM
|
|
||||||
};
|
|
||||||
|
|
||||||
// Standard Constructors and Named Constructors
|
|
||||||
Ordering();
|
|
||||||
Ordering(const gtsam::Ordering& other);
|
|
||||||
|
|
||||||
template <FACTOR_GRAPH = {gtsam::NonlinearFactorGraph,
|
|
||||||
gtsam::GaussianFactorGraph}>
|
|
||||||
static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph);
|
|
||||||
|
|
||||||
// Testable
|
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
|
||||||
bool equals(const gtsam::Ordering& ord, double tol) const;
|
|
||||||
|
|
||||||
// Standard interface
|
|
||||||
size_t size() const;
|
|
||||||
size_t at(size_t key) const;
|
|
||||||
void push_back(size_t key);
|
|
||||||
|
|
||||||
// enabling serialization functionality
|
|
||||||
void serialize() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
#include <gtsam/nonlinear/GraphvizFormatting.h>
|
#include <gtsam/nonlinear/GraphvizFormatting.h>
|
||||||
class GraphvizFormatting : gtsam::DotWriter {
|
class GraphvizFormatting : gtsam::DotWriter {
|
||||||
GraphvizFormatting();
|
GraphvizFormatting();
|
||||||
|
@ -207,18 +95,17 @@ class NonlinearFactorGraph {
|
||||||
gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const;
|
gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const;
|
||||||
gtsam::NonlinearFactorGraph clone() const;
|
gtsam::NonlinearFactorGraph clone() const;
|
||||||
|
|
||||||
// enabling serialization functionality
|
|
||||||
void serialize() const;
|
|
||||||
|
|
||||||
string dot(
|
string dot(
|
||||||
const gtsam::Values& values,
|
const gtsam::Values& values,
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
const GraphvizFormatting& writer = GraphvizFormatting());
|
const GraphvizFormatting& formatting = GraphvizFormatting());
|
||||||
void saveGraph(const string& s, const gtsam::Values& values,
|
void saveGraph(
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const string& s, const gtsam::Values& values,
|
||||||
gtsam::DefaultKeyFormatter,
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
const GraphvizFormatting& writer =
|
const GraphvizFormatting& formatting = GraphvizFormatting()) const;
|
||||||
GraphvizFormatting()) const;
|
|
||||||
|
// enabling serialization functionality
|
||||||
|
void serialize() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||||
|
|
|
@ -16,12 +16,8 @@
|
||||||
* @author Richard Roberts
|
* @author Richard Roberts
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
|
||||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
||||||
#include <boost/range/adaptor/reversed.hpp>
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -29,28 +25,7 @@ namespace gtsam {
|
||||||
template class FactorGraph<SymbolicConditional>;
|
template class FactorGraph<SymbolicConditional>;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
bool SymbolicBayesNet::equals(const This& bn, double tol) const
|
bool SymbolicBayesNet::equals(const This& bn, double tol) const {
|
||||||
{
|
|
||||||
return Base::equals(bn, tol);
|
return Base::equals(bn, tol);
|
||||||
}
|
}
|
||||||
|
} // namespace gtsam
|
||||||
/* ************************************************************************* */
|
|
||||||
void SymbolicBayesNet::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const
|
|
||||||
{
|
|
||||||
std::ofstream of(s.c_str());
|
|
||||||
of << "digraph G{\n";
|
|
||||||
|
|
||||||
for (auto conditional: boost::adaptors::reverse(*this)) {
|
|
||||||
SymbolicConditional::Frontals frontals = conditional->frontals();
|
|
||||||
Key me = frontals.front();
|
|
||||||
SymbolicConditional::Parents parents = conditional->parents();
|
|
||||||
for(Key p: parents)
|
|
||||||
of << p << "->" << me << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
of << "}";
|
|
||||||
of.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
|
@ -19,19 +19,19 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
#include <gtsam/base/types.h>
|
#include <gtsam/base/types.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** Symbolic Bayes Net
|
/**
|
||||||
* \nosubgrouping
|
* A SymbolicBayesNet is a Bayes Net of purely symbolic conditionals.
|
||||||
|
* @addtogroup symbolic
|
||||||
*/
|
*/
|
||||||
class SymbolicBayesNet : public FactorGraph<SymbolicConditional> {
|
class SymbolicBayesNet : public BayesNet<SymbolicConditional> {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
typedef BayesNet<SymbolicConditional> Base;
|
||||||
typedef FactorGraph<SymbolicConditional> Base;
|
|
||||||
typedef SymbolicBayesNet This;
|
typedef SymbolicBayesNet This;
|
||||||
typedef SymbolicConditional ConditionalType;
|
typedef SymbolicConditional ConditionalType;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
@ -45,15 +45,20 @@ namespace gtsam {
|
||||||
|
|
||||||
/** Construct from iterator over conditionals */
|
/** Construct from iterator over conditionals */
|
||||||
template <typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
|
||||||
|
: Base(firstConditional, lastConditional) {}
|
||||||
|
|
||||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||||
template <class CONTAINER>
|
template <class CONTAINER>
|
||||||
explicit SymbolicBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
explicit SymbolicBayesNet(const CONTAINER& conditionals) {
|
||||||
|
push_back(conditionals);
|
||||||
|
}
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
/** Implicit copy/downcast constructor to override explicit template
|
||||||
|
* container constructor */
|
||||||
template <class DERIVEDCONDITIONAL>
|
template <class DERIVEDCONDITIONAL>
|
||||||
SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
explicit SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
|
||||||
|
: Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~SymbolicBayesNet() {}
|
virtual ~SymbolicBayesNet() {}
|
||||||
|
@ -75,13 +80,6 @@ namespace gtsam {
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
/// @name Standard Interface
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
GTSAM_EXPORT void saveGraph(const std::string &s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
|
||||||
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
friend class boost::serialization::access;
|
friend class boost::serialization::access;
|
||||||
|
|
|
@ -3,11 +3,6 @@
|
||||||
//*************************************************************************
|
//*************************************************************************
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
|
||||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
|
||||||
|
|
||||||
// ###################
|
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicFactor.h>
|
#include <gtsam/symbolic/SymbolicFactor.h>
|
||||||
virtual class SymbolicFactor {
|
virtual class SymbolicFactor {
|
||||||
// Standard Constructors and Named Constructors
|
// Standard Constructors and Named Constructors
|
||||||
|
@ -82,6 +77,14 @@ virtual class SymbolicFactorGraph {
|
||||||
const gtsam::KeyVector& key_vector,
|
const gtsam::KeyVector& key_vector,
|
||||||
const gtsam::Ordering& marginalizedVariableOrdering);
|
const gtsam::Ordering& marginalizedVariableOrdering);
|
||||||
gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector);
|
gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector);
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||||
|
@ -103,6 +106,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor {
|
||||||
bool equals(const gtsam::SymbolicConditional& other, double tol) const;
|
bool equals(const gtsam::SymbolicConditional& other, double tol) const;
|
||||||
|
|
||||||
// Standard interface
|
// Standard interface
|
||||||
|
gtsam::Key firstFrontalKey() const;
|
||||||
size_t nrFrontals() const;
|
size_t nrFrontals() const;
|
||||||
size_t nrParents() const;
|
size_t nrParents() const;
|
||||||
};
|
};
|
||||||
|
@ -125,6 +129,14 @@ class SymbolicBayesNet {
|
||||||
gtsam::SymbolicConditional* back() const;
|
gtsam::SymbolicConditional* back() const;
|
||||||
void push_back(gtsam::SymbolicConditional* conditional);
|
void push_back(gtsam::SymbolicConditional* conditional);
|
||||||
void push_back(const gtsam::SymbolicBayesNet& bayesNet);
|
void push_back(const gtsam::SymbolicBayesNet& bayesNet);
|
||||||
|
|
||||||
|
string dot(
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
|
void saveGraph(
|
||||||
|
string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicBayesTree.h>
|
#include <gtsam/symbolic/SymbolicBayesTree.h>
|
||||||
|
@ -173,29 +185,4 @@ class SymbolicBayesTreeClique {
|
||||||
void deleteCachedShortcuts();
|
void deleteCachedShortcuts();
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/inference/VariableIndex.h>
|
|
||||||
class VariableIndex {
|
|
||||||
// Standard Constructors and Named Constructors
|
|
||||||
VariableIndex();
|
|
||||||
// TODO: Templetize constructor when wrap supports it
|
|
||||||
// template<T = {gtsam::FactorGraph}>
|
|
||||||
// VariableIndex(const T& factorGraph, size_t nVariables);
|
|
||||||
// VariableIndex(const T& factorGraph);
|
|
||||||
VariableIndex(const gtsam::SymbolicFactorGraph& sfg);
|
|
||||||
VariableIndex(const gtsam::GaussianFactorGraph& gfg);
|
|
||||||
VariableIndex(const gtsam::NonlinearFactorGraph& fg);
|
|
||||||
VariableIndex(const gtsam::VariableIndex& other);
|
|
||||||
|
|
||||||
// Testable
|
|
||||||
bool equals(const gtsam::VariableIndex& other, double tol) const;
|
|
||||||
void print(string s = "VariableIndex: ",
|
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
|
||||||
|
|
||||||
// Standard interface
|
|
||||||
size_t size() const;
|
|
||||||
size_t nFactors() const;
|
|
||||||
size_t nEntries() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -15,13 +15,16 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/base/VectorSpace.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <boost/make_shared.hpp>
|
||||||
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
|
||||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
@ -30,7 +33,6 @@ static const Key _L_ = 0;
|
||||||
static const Key _A_ = 1;
|
static const Key _A_ = 1;
|
||||||
static const Key _B_ = 2;
|
static const Key _B_ = 2;
|
||||||
static const Key _C_ = 3;
|
static const Key _C_ = 3;
|
||||||
static const Key _D_ = 4;
|
|
||||||
|
|
||||||
static SymbolicConditional::shared_ptr
|
static SymbolicConditional::shared_ptr
|
||||||
B(new SymbolicConditional(_B_)),
|
B(new SymbolicConditional(_B_)),
|
||||||
|
@ -78,14 +80,41 @@ TEST( SymbolicBayesNet, combine )
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(SymbolicBayesNet, saveGraph) {
|
TEST(SymbolicBayesNet, Dot) {
|
||||||
|
using symbol_shorthand::A;
|
||||||
|
using symbol_shorthand::X;
|
||||||
SymbolicBayesNet bn;
|
SymbolicBayesNet bn;
|
||||||
bn += SymbolicConditional(_A_, _B_);
|
bn += SymbolicConditional(X(3), X(2), A(2));
|
||||||
KeyVector keys {_B_, _C_, _D_};
|
bn += SymbolicConditional(X(2), X(1), A(1));
|
||||||
bn += SymbolicConditional::FromKeys(keys,2);
|
bn += SymbolicConditional(X(1));
|
||||||
bn += SymbolicConditional(_D_);
|
|
||||||
|
|
||||||
bn.saveGraph("SymbolicBayesNet.dot");
|
DotWriter writer;
|
||||||
|
writer.positionHints.emplace('a', 2);
|
||||||
|
writer.positionHints.emplace('x', 1);
|
||||||
|
writer.boxes.emplace(A(1));
|
||||||
|
writer.boxes.emplace(A(2));
|
||||||
|
|
||||||
|
auto position = writer.variablePos(A(1));
|
||||||
|
CHECK(position);
|
||||||
|
EXPECT(assert_equal(Vector2(1, 2), *position, 1e-5));
|
||||||
|
|
||||||
|
string actual = bn.dot(DefaultKeyFormatter, writer);
|
||||||
|
bn.saveGraph("bn.dot", DefaultKeyFormatter, writer);
|
||||||
|
EXPECT(actual ==
|
||||||
|
"digraph {\n"
|
||||||
|
" size=\"5,5\";\n"
|
||||||
|
"\n"
|
||||||
|
" vara1[label=\"a1\", pos=\"1,2!\", shape=box];\n"
|
||||||
|
" vara2[label=\"a2\", pos=\"2,2!\", shape=box];\n"
|
||||||
|
" varx1[label=\"x1\", pos=\"1,1!\"];\n"
|
||||||
|
" varx2[label=\"x2\", pos=\"2,1!\"];\n"
|
||||||
|
" varx3[label=\"x3\", pos=\"3,1!\"];\n"
|
||||||
|
"\n"
|
||||||
|
" varx1->varx2\n"
|
||||||
|
" vara1->varx2\n"
|
||||||
|
" varx2->varx3\n"
|
||||||
|
" vara2->varx3\n"
|
||||||
|
"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -54,6 +54,7 @@ set(ignore
|
||||||
set(interface_headers
|
set(interface_headers
|
||||||
${PROJECT_SOURCE_DIR}/gtsam/gtsam.i
|
${PROJECT_SOURCE_DIR}/gtsam/gtsam.i
|
||||||
${PROJECT_SOURCE_DIR}/gtsam/base/base.i
|
${PROJECT_SOURCE_DIR}/gtsam/base/base.i
|
||||||
|
${PROJECT_SOURCE_DIR}/gtsam/inference/inference.i
|
||||||
${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i
|
${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i
|
||||||
${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i
|
${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i
|
||||||
${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i
|
${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
/* Please refer to:
|
||||||
|
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
|
||||||
|
* These are required to save one copy operation on Python calls.
|
||||||
|
*
|
||||||
|
* NOTES
|
||||||
|
* =================
|
||||||
|
*
|
||||||
|
* `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11
|
||||||
|
* automatic STL binding, such that the raw objects can be accessed in Python.
|
||||||
|
* Without this they will be automatically converted to a Python object, and all
|
||||||
|
* mutations on Python side will not be reflected on C++.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
/* Please refer to:
|
||||||
|
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
|
||||||
|
* These are required to save one copy operation on Python calls.
|
||||||
|
*
|
||||||
|
* NOTES
|
||||||
|
* =================
|
||||||
|
*
|
||||||
|
* `py::bind_vector` and similar machinery gives the std container a Python-like
|
||||||
|
* interface, but without the `<pybind11/stl.h>` copying mechanism. Combined
|
||||||
|
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
|
||||||
|
* and saves one copy operation.
|
||||||
|
*/
|
||||||
|
|
|
@ -78,7 +78,7 @@ class TestGraphvizFormatting(GtsamTestCase):
|
||||||
graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X
|
graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X
|
||||||
graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y
|
graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y
|
||||||
self.assertEqual(self.graph.dot(self.values,
|
self.assertEqual(self.graph.dot(self.values,
|
||||||
writer=graphviz_formatting),
|
formatting=graphviz_formatting),
|
||||||
textwrap.dedent(expected_result))
|
textwrap.dedent(expected_result))
|
||||||
|
|
||||||
def test_factor_points(self):
|
def test_factor_points(self):
|
||||||
|
@ -100,7 +100,7 @@ class TestGraphvizFormatting(GtsamTestCase):
|
||||||
graphviz_formatting.plotFactorPoints = False
|
graphviz_formatting.plotFactorPoints = False
|
||||||
|
|
||||||
self.assertEqual(self.graph.dot(self.values,
|
self.assertEqual(self.graph.dot(self.values,
|
||||||
writer=graphviz_formatting),
|
formatting=graphviz_formatting),
|
||||||
textwrap.dedent(expected_result))
|
textwrap.dedent(expected_result))
|
||||||
|
|
||||||
def test_width_height(self):
|
def test_width_height(self):
|
||||||
|
@ -127,7 +127,7 @@ class TestGraphvizFormatting(GtsamTestCase):
|
||||||
graphviz_formatting.figureHeightInches = 10
|
graphviz_formatting.figureHeightInches = 10
|
||||||
|
|
||||||
self.assertEqual(self.graph.dot(self.values,
|
self.assertEqual(self.graph.dot(self.values,
|
||||||
writer=graphviz_formatting),
|
formatting=graphviz_formatting),
|
||||||
textwrap.dedent(expected_result))
|
textwrap.dedent(expected_result))
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue