Merge pull request #971 from borglab/feature/notebook_dot

release/4.3a0
Frank Dellaert 2021-12-20 22:43:54 -05:00 committed by GitHub
commit 168a67da05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1207 additions and 290 deletions

View File

@ -173,7 +173,7 @@ TEST(Matrix, stack )
{
Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished();
Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished();
Matrix AB = stack(2, &A, &B);
Matrix AB = gtsam::stack(2, &A, &B);
Matrix C(5, 2);
for (int i = 0; i < 2; i++)
for (int j = 0; j < 2; j++)
@ -187,7 +187,7 @@ TEST(Matrix, stack )
std::vector<gtsam::Matrix> matrices;
matrices.push_back(A);
matrices.push_back(B);
Matrix AB2 = stack(matrices);
Matrix AB2 = gtsam::stack(matrices);
EQUALITY(C,AB2);
}

View File

@ -248,8 +248,9 @@ namespace gtsam {
void dot(std::ostream& os, bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
<< "\"]\n";
for (size_t i = 0; i < branches_.size(); i++) {
NodePtr branch = branches_[i];
size_t B = branches_.size();
for (size_t i = 0; i < B; i++) {
const NodePtr& branch = branches_[i];
// Check if zero
if (!showZero) {
@ -258,8 +259,10 @@ namespace gtsam {
}
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
if (i == 0) os << " [style=dashed]";
if (i > 1) os << " [style=bold]";
if (B == 2) {
if (i == 0) os << " [style=dashed]";
if (i > 1) os << " [style=bold]";
}
os << std::endl;
branch->dot(os, showZero);
}
@ -671,7 +674,14 @@ namespace gtsam {
int result = system(
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
}
}
template<typename L, typename Y>
std::string DecisionTree<L, Y>::dot(bool showZero) const {
std::stringstream ss;
dot(ss, showZero);
return ss.str();
}
/*********************************************************************************/

View File

@ -198,6 +198,9 @@ namespace gtsam {
/** output to graphviz format, open a file */
void dot(const std::string& name, bool showZero = true) const;
/** output to graphviz format string */
std::string dot(bool showZero = true) const;
/// @name Advanced Interface
/// @{

View File

@ -38,7 +38,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
string dot(bool showZero = false) const;
};
#include <gtsam/discrete/DiscreteConditional.h>
@ -52,8 +52,6 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);
size_t size() const; // TODO(dellaert): why do I have to repeat???
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
void print(string s = "Discrete Conditional\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
@ -82,6 +80,8 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
@ -98,9 +98,19 @@ class DiscreteBayesTree {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
};
#include <gtsam/inference/DotWriter.h>
class DotWriter {
DotWriter();
};
#include <gtsam/discrete/DiscreteFactorGraph.h>
class DiscreteFactorGraph {
DiscreteFactorGraph();
@ -117,7 +127,15 @@ class DiscreteFactorGraph {
void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
string dot(const gtsam::DotWriter& dotWriter = gtsam::DotWriter(),
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter(),
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;

View File

@ -135,7 +135,7 @@ TEST(DiscreteBayesNet, Asia) {
}
/* ************************************************************************* */
TEST_UNSAFE(DiscreteBayesNet, Sugar) {
TEST(DiscreteBayesNet, Sugar) {
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
DiscreteBayesNet bn;
@ -149,6 +149,29 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) {
bn.add(C | S = "1/1/2 5/2/3");
}
/* ************************************************************************* */
TEST(DiscreteBayesNet, Dot) {
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
Either(5, 2);
DiscreteBayesNet fragment;
fragment.add(Asia % "99/1");
fragment.add(Smoking % "50/50");
fragment.add(Tuberculosis | Asia = "99/1 95/5");
fragment.add(LungCancer | Smoking = "99/1 90/10");
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
string actual = fragment.dot();
EXPECT(actual ==
"digraph G{\n"
"0->3\n"
"4->6\n"
"3->5\n"
"6->5\n"
"}");
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -26,76 +26,89 @@ using namespace boost::assign;
#include <CppUnitLite/TestHarness.h>
#include <iostream>
#include <vector>
using namespace std;
using namespace gtsam;
static bool debug = false;
static constexpr bool debug = false;
/* ************************************************************************* */
TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
const int nrNodes = 15;
const size_t nrStates = 2;
// define variables
vector<DiscreteKey> key;
for (int i = 0; i < nrNodes; i++) {
DiscreteKey key_i(i, nrStates);
key.push_back(key_i);
}
// create a thin-tree Bayesnet, a la Jean-Guillaume
struct TestFixture {
vector<DiscreteKey> keys;
DiscreteBayesNet bayesNet;
bayesNet.add(key[14] % "1/3");
boost::shared_ptr<DiscreteBayesTree> bayesTree;
bayesNet.add(key[13] | key[14] = "1/3 3/1");
bayesNet.add(key[12] | key[14] = "3/1 3/1");
/**
* Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student),
* and then create the Bayes tree from it.
*/
TestFixture() {
// Define variables.
for (int i = 0; i < 15; i++) {
DiscreteKey key_i(i, 2);
keys.push_back(key_i);
}
bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
// Create thin-tree Bayesnet.
bayesNet.add(keys[14] % "1/3");
bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
bayesNet.add(keys[13] | keys[14] = "1/3 3/1");
bayesNet.add(keys[12] | keys[14] = "3/1 3/1");
bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1");
bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1");
bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4");
bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1");
bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1");
bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1");
bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4");
bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1");
bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1");
bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1");
bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4");
bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1");
// Create a BayesTree out of the Bayes net.
bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
}
};
/* ************************************************************************* */
TEST(DiscreteBayesTree, ThinTree) {
const TestFixture self;
const auto& keys = self.keys;
if (debug) {
GTSAM_PRINT(bayesNet);
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
GTSAM_PRINT(self.bayesNet);
self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
}
// create a BayesTree out of a Bayes net
auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
if (debug) {
GTSAM_PRINT(*bayesTree);
bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
GTSAM_PRINT(*self.bayesTree);
self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
}
// Check frontals and parents
for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
auto clique_i = (*bayesTree)[i];
auto clique_i = (*self.bayesTree)[i];
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
}
auto R = bayesTree->roots().front();
auto R = self.bayesTree->roots().front();
// Check whether BN and BT give the same answer on all configurations
vector<DiscreteValues> allPosbValues = cartesianProduct(
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
vector<DiscreteValues> allPosbValues =
cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] &
keys[5] & keys[6] & keys[7] & keys[8] & keys[9] &
keys[10] & keys[11] & keys[12] & keys[13] & keys[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i];
double expected = bayesNet.evaluate(x);
double actual = bayesTree->evaluate(x);
double expected = self.bayesNet.evaluate(x);
double actual = self.bayesTree->evaluate(x);
DOUBLES_EQUAL(expected, actual, 1e-9);
}
@ -107,7 +120,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i];
double px = bayesTree->evaluate(x);
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
@ -141,46 +154,46 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
DiscreteValues all1 = allPosbValues.back();
// check separator marginal P(S0)
auto clique = (*bayesTree)[0];
auto clique = (*self.bayesTree)[0];
DiscreteFactorGraph separatorMarginal0 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// check separator marginal P(S9), should be P(14)
clique = (*bayesTree)[9];
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*bayesTree)[11];
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
// check shortcut P(S9||R) to root
clique = (*bayesTree)[9];
clique = (*self.bayesTree)[9];
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
LONGS_EQUAL(1, shortcut.size());
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S8||R) to root
clique = (*bayesTree)[8];
clique = (*self.bayesTree)[8];
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S2||R) to root
clique = (*bayesTree)[2];
clique = (*self.bayesTree)[2];
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// check shortcut P(S0||R) to root
clique = (*bayesTree)[0];
clique = (*self.bayesTree)[0];
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
// calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
for (auto clique : cliques) {
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
if (debug) {
@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
// Check all marginals
DiscreteFactor::shared_ptr marginalFactor;
for (size_t i = 0; i < 15; i++) {
marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1);
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
}
@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
DiscreteBayesNet::shared_ptr actualJoint;
// Check joint P(8, 2)
actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
// Check joint P(1, 2)
actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
// Check joint P(2, 4)
actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 5)
actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 6)
actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
// Check joint P(4, 11)
actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
}
/* ************************************************************************* */
TEST(DiscreteBayesTree, Dot) {
const TestFixture self;
string actual = self.bayesTree->dot();
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13,11,6,7\"];\n"
"0->1\n"
"1[label=\"14 : 11,13\"];\n"
"1->2\n"
"2[label=\"9,12 : 14\"];\n"
"2->3\n"
"3[label=\"3 : 9,12\"];\n"
"2->4\n"
"4[label=\"2 : 9,12\"];\n"
"2->5\n"
"5[label=\"8 : 12,14\"];\n"
"5->6\n"
"6[label=\"1 : 8,12\"];\n"
"5->7\n"
"7[label=\"0 : 8,12\"];\n"
"1->8\n"
"8[label=\"10 : 13,14\"];\n"
"8->9\n"
"9[label=\"5 : 10,13\"];\n"
"8->10\n"
"10[label=\"4 : 10,13\"];\n"
"}");
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -359,6 +359,31 @@ cout << unicorns;
}
#endif
/* ************************************************************************* */
TEST(DiscreteFactorGraph, Dot) {
// Declare a bunch of keys
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
// Create Factor graph
DiscreteFactorGraph graph;
graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6");
string actual = graph.dot();
string expected =
"graph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"0\"];\n"
" var1[label=\"1\"];\n"
" var2[label=\"2\"];\n"
"\n"
" var0--var1;\n"
" var0--var2;\n"
"}\n";
EXPECT(actual == expected);
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -35,21 +35,39 @@ void BayesNet<CONDITIONAL>::print(
/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::saveGraph(const std::string& s,
const KeyFormatter& keyFormatter) const {
std::ofstream of(s.c_str());
of << "digraph G{\n";
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const {
os << "digraph G{\n";
for (auto conditional : boost::adaptors::reverse(*this)) {
typename CONDITIONAL::Frontals frontals = conditional->frontals();
Key me = frontals.front();
typename CONDITIONAL::Parents parents = conditional->parents();
for (Key p : parents)
of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl;
for (auto conditional : *this) {
auto frontals = conditional->frontals();
const Key me = frontals.front();
auto parents = conditional->parents();
for (const Key& p : parents)
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
}
of << "}";
os << "}";
std::flush(os);
}
/* ************************************************************************* */
template <class CONDITIONAL>
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
std::stringstream ss;
dot(ss, keyFormatter);
return ss.str();
}
/* ************************************************************************* */
template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const {
std::ofstream of(filename.c_str());
dot(of, keyFormatter);
of.close();
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -64,11 +64,21 @@ namespace gtsam {
/// @}
/// @name Standard Interface
/// @name Graph Display
/// @{
void saveGraph(const std::string& s,
/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
};
}

View File

@ -63,20 +63,40 @@ namespace gtsam {
}
/* ************************************************************************* */
template<class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const {
if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!");
std::ofstream of(s.c_str());
of<< "digraph G{\n";
for(const sharedClique& root: roots_)
saveGraph(of, root, keyFormatter);
of<<"}";
template <class CLIQUE>
void BayesTree<CLIQUE>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const {
if (roots_.empty())
throw std::invalid_argument(
"the root of Bayes tree has not been initialized!");
os << "digraph G{\n";
for (const sharedClique& root : roots_) dot(os, root, keyFormatter);
os << "}";
std::flush(os);
}
/* ************************************************************************* */
template <class CLIQUE>
std::string BayesTree<CLIQUE>::dot(const KeyFormatter& keyFormatter) const {
std::stringstream ss;
dot(ss, keyFormatter);
return ss.str();
}
/* ************************************************************************* */
template <class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const {
std::ofstream of(filename.c_str());
dot(of, keyFormatter);
of.close();
}
/* ************************************************************************* */
template<class CLIQUE>
void BayesTree<CLIQUE>::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const {
template <class CLIQUE>
void BayesTree<CLIQUE>::dot(std::ostream& s, sharedClique clique,
const KeyFormatter& indexFormatter,
int parentnum) const {
static int num = 0;
bool first = true;
std::stringstream out;
@ -107,7 +127,7 @@ namespace gtsam {
for (sharedClique c : clique->children) {
num++;
saveGraph(s, c, indexFormatter, parentnum);
dot(s, c, indexFormatter, parentnum);
}
}

View File

@ -182,13 +182,20 @@ namespace gtsam {
*/
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;
/**
* Read only with side effects
*/
/// @name Graph Display
/// @{
/** saves the Tree to a text file in GraphViz format */
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
/// @name Advanced Interface
/// @{
@ -236,8 +243,8 @@ namespace gtsam {
protected:
/** private helper method for saving the Tree to a text file in GraphViz format */
void saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter,
int parentnum = 0) const;
void dot(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter,
int parentnum = 0) const;
/** Gather data on a single clique */
void getCliqueData(sharedClique clique, BayesTreeCliqueData* stats) const;
@ -249,7 +256,7 @@ namespace gtsam {
void fillNodesIndex(const sharedClique& subtree);
// Friend JunctionTree because it directly fills roots and nodes index.
template<class BAYESRTEE, class GRAPH> friend class EliminatableClusterTree;
template<class BAYESTREE, class GRAPH> friend class EliminatableClusterTree;
private:
/** Serialization function */

View File

@ -0,0 +1,93 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DotWriter.cpp
* @brief Graphviz formatting for factor graphs.
* @author Frank Dellaert
* @date December, 2021
*/
#include <gtsam/base/Vector.h>
#include <gtsam/inference/DotWriter.h>
#include <ostream>
using namespace std;
namespace gtsam {
void DotWriter::writePreamble(ostream* os) const {
*os << "graph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n";
}
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position,
ostream* os) {
// Label the node with the label from the KeyFormatter
*os << " var" << key << "[label=\"" << keyFormatter(key) << "\"";
if (position) {
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
}
*os << "];\n";
}
void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
ostream* os) {
*os << " factor" << i << "[label=\"\", shape=point";
if (position) {
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
}
*os << "];\n";
}
void DotWriter::ConnectVariables(Key key1, Key key2, ostream* os) {
*os << " var" << key1 << "--"
<< "var" << key2 << ";\n";
}
void DotWriter::ConnectVariableFactor(Key key, size_t i, ostream* os) {
*os << " var" << key << "--"
<< "factor" << i << ";\n";
}
void DotWriter::processFactor(size_t i, const KeyVector& keys,
const boost::optional<Vector2>& position,
ostream* os) const {
if (plotFactorPoints) {
if (binaryEdges && keys.size() == 2) {
ConnectVariables(keys[0], keys[1], os);
} else {
// Create dot for the factor.
DrawFactor(i, position, os);
// Make factor-variable connections
if (connectKeysToFactor) {
for (Key key : keys) {
ConnectVariableFactor(key, i, os);
}
}
}
} else {
// just connect variables in a clique
for (Key key1 : keys) {
for (Key key2 : keys) {
if (key2 > key1) {
ConnectVariables(key1, key2, os);
}
}
}
}
}
} // namespace gtsam

View File

@ -0,0 +1,69 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DotWriter.h
* @brief Graphviz formatter
* @author Frank Dellaert
* @date December, 2021
*/
#pragma once
#include <gtsam/base/FastVector.h>
#include <gtsam/base/Vector.h>
#include <gtsam/inference/Key.h>
#include <iosfwd>
namespace gtsam {
/// Graphviz formatter.
struct GTSAM_EXPORT DotWriter {
double figureWidthInches; ///< The figure width on paper in inches
double figureHeightInches; ///< The figure height on paper in inches
bool plotFactorPoints; ///< Plots each factor as a dot between the variables
bool connectKeysToFactor; ///< Draw a line from each key within a factor to
///< the dot of the factor
bool binaryEdges; ///< just use non-dotted edges for binary factors
DotWriter()
: figureWidthInches(5),
figureHeightInches(5),
plotFactorPoints(true),
connectKeysToFactor(true),
binaryEdges(true) {}
/// Write out preamble, including size.
void writePreamble(std::ostream* os) const;
/// Create a variable dot fragment.
static void DrawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position,
std::ostream* os);
/// Create factor dot.
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
std::ostream* os);
/// Connect two variables.
static void ConnectVariables(Key key1, Key key2, std::ostream* os);
/// Connect variable and factor.
static void ConnectVariableFactor(Key key, size_t i, std::ostream* os);
/// Draw a single factor, specified by its index i and its variable keys.
void processFactor(size_t i, const KeyVector& keys,
const boost::optional<Vector2>& position,
std::ostream* os) const;
};
} // namespace gtsam

View File

@ -26,6 +26,7 @@
#include <stdio.h>
#include <algorithm>
#include <iostream> // for cout :-(
#include <fstream>
#include <sstream>
#include <string>
@ -125,4 +126,48 @@ FactorIndices FactorGraph<FACTOR>::add_factors(const CONTAINER& factors,
return newFactorIndices;
}
/* ************************************************************************* */
template <class FACTOR>
void FactorGraph<FACTOR>::dot(std::ostream& os, const DotWriter& writer,
const KeyFormatter& keyFormatter) const {
writer.writePreamble(&os);
// Create nodes for each variable in the graph
for (Key key : keys()) {
writer.DrawVariable(key, keyFormatter, boost::none, &os);
}
os << "\n";
// Create factors and variable connections
for (size_t i = 0; i < size(); ++i) {
const auto& factor = at(i);
if (factor) {
const KeyVector& factorKeys = factor->keys();
writer.processFactor(i, factorKeys, boost::none, &os);
}
}
os << "}\n";
std::flush(os);
}
/* ************************************************************************* */
template <class FACTOR>
std::string FactorGraph<FACTOR>::dot(const DotWriter& writer,
const KeyFormatter& keyFormatter) const {
std::stringstream ss;
dot(ss, writer, keyFormatter);
return ss.str();
}
/* ************************************************************************* */
template <class FACTOR>
void FactorGraph<FACTOR>::saveGraph(const std::string& filename,
const DotWriter& writer,
const KeyFormatter& keyFormatter) const {
std::ofstream of(filename.c_str());
dot(of, writer, keyFormatter);
of.close();
}
} // namespace gtsam

View File

@ -22,9 +22,10 @@
#pragma once
#include <gtsam/inference/DotWriter.h>
#include <gtsam/inference/Key.h>
#include <gtsam/base/FastVector.h>
#include <gtsam/base/Testable.h>
#include <gtsam/inference/Key.h>
#include <Eigen/Core> // for Eigen::aligned_allocator
@ -36,6 +37,7 @@
#include <string>
#include <type_traits>
#include <utility>
#include <iosfwd>
namespace gtsam {
/// Define collection type:
@ -371,6 +373,23 @@ class FactorGraph {
return factors_.erase(first, last);
}
/// @}
/// @name Graph Display
/// @{
/// Output to graphviz format, stream version.
void dot(std::ostream& os, const DotWriter& writer = DotWriter(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(const DotWriter& writer = DotWriter(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const DotWriter& writer = DotWriter(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
/// @name Advanced Interface
/// @{

View File

@ -379,7 +379,7 @@ namespace gtsam {
gttic(Compute_minimizing_step_size);
// Compute minimizing step size
double step = -gradientSqNorm / dot(Rg, Rg);
double step = -gradientSqNorm / gtsam::dot(Rg, Rg);
gttoc(Compute_minimizing_step_size);
gttic(Compute_point);

View File

@ -0,0 +1,136 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file GraphvizFormatting.cpp
* @brief Graphviz formatter for NonlinearFactorGraph
* @author Frank Dellaert
* @date December, 2021
*/
#include <gtsam/nonlinear/GraphvizFormatting.h>
#include <gtsam/nonlinear/Values.h>
// TODO(frank): nonlinear should not depend on geometry:
#include <gtsam/geometry/Pose2.h>
#include <gtsam/geometry/Pose3.h>
#include <limits>
namespace gtsam {
Vector2 GraphvizFormatting::findBounds(const Values& values,
const KeySet& keys) const {
Vector2 min;
min.x() = std::numeric_limits<double>::infinity();
min.y() = std::numeric_limits<double>::infinity();
for (const Key& key : keys) {
if (values.exists(key)) {
boost::optional<Vector2> xy = operator()(values.at(key));
if (xy) {
if (xy->x() < min.x()) min.x() = xy->x();
if (xy->y() < min.y()) min.y() = xy->y();
}
}
}
return min;
}
boost::optional<Vector2> GraphvizFormatting::operator()(
const Value& value) const {
Vector3 t;
if (const GenericValue<Pose2>* p =
dynamic_cast<const GenericValue<Pose2>*>(&value)) {
t << p->value().x(), p->value().y(), 0;
} else if (const GenericValue<Vector2>* p =
dynamic_cast<const GenericValue<Vector2>*>(&value)) {
t << p->value().x(), p->value().y(), 0;
} else if (const GenericValue<Pose3>* p =
dynamic_cast<const GenericValue<Pose3>*>(&value)) {
t = p->value().translation();
} else if (const GenericValue<Point3>* p =
dynamic_cast<const GenericValue<Point3>*>(&value)) {
t = p->value();
} else {
return boost::none;
}
double x, y;
switch (paperHorizontalAxis) {
case X:
x = t.x();
break;
case Y:
x = t.y();
break;
case Z:
x = t.z();
break;
case NEGX:
x = -t.x();
break;
case NEGY:
x = -t.y();
break;
case NEGZ:
x = -t.z();
break;
default:
throw std::runtime_error("Invalid enum value");
}
switch (paperVerticalAxis) {
case X:
y = t.x();
break;
case Y:
y = t.y();
break;
case Z:
y = t.z();
break;
case NEGX:
y = -t.x();
break;
case NEGY:
y = -t.y();
break;
case NEGZ:
y = -t.z();
break;
default:
throw std::runtime_error("Invalid enum value");
}
return Vector2(x, y);
}
// Return affinely transformed variable position if it exists.
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
const Vector2& min,
Key key) const {
if (!values.exists(key)) return boost::none;
boost::optional<Vector2> xy = operator()(values.at(key));
if (xy) {
xy->x() = scale * (xy->x() - min.x());
xy->y() = scale * (xy->y() - min.y());
}
return xy;
}
// Return affinely transformed factor position if it exists.
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
size_t i) const {
if (factorPositions.size() == 0) return boost::none;
auto it = factorPositions.find(i);
if (it == factorPositions.end()) return boost::none;
auto pos = it->second;
return Vector2(scale * (pos.x() - min.x()), scale * (pos.y() - min.y()));
}
} // namespace gtsam

View File

@ -0,0 +1,69 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file GraphvizFormatting.h
* @brief Graphviz formatter for NonlinearFactorGraph
* @author Frank Dellaert
* @date December, 2021
*/
#pragma once
#include <gtsam/inference/DotWriter.h>
namespace gtsam {
class Values;
class Value;
/**
* Formatting options and functions for saving a NonlinearFactorGraph instance
* in GraphViz format.
*/
struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
/// World axes to be assigned to paper axes
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
///< paper axis
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
///< axis
double scale; ///< Scale all positions to reduce / increase density
bool mergeSimilarFactors; ///< Merge multiple factors that have the same
///< connectivity
/// (optional for each factor) Manually specify factor "dot" positions:
std::map<size_t, Vector2> factorPositions;
/// Default constructor sets up robot coordinates. Paper horizontal is robot
/// Y, paper vertical is robot X. Default figure size of 5x5 in.
GraphvizFormatting()
: paperHorizontalAxis(Y),
paperVerticalAxis(X),
scale(1),
mergeSimilarFactors(false) {}
// Find bounds
Vector2 findBounds(const Values& values, const KeySet& keys) const;
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
boost::optional<Vector2> operator()(const Value& value) const;
/// Return affinely transformed variable position if it exists.
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,
Key key) const;
/// Return affinely transformed factor position if it exists.
boost::optional<Vector2> factorPos(const Vector2& min, size_t i) const;
};
} // namespace gtsam

View File

@ -26,6 +26,7 @@
#include <gtsam/linear/linearExceptions.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/inference/DotWriter.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/config.h> // for GTSAM_USE_TBB
@ -35,7 +36,6 @@
#include <cmath>
#include <fstream>
#include <limits>
using namespace std;
@ -91,89 +91,25 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
}
/* ************************************************************************* */
void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values,
const GraphvizFormatting& formatting,
const KeyFormatter& keyFormatter) const
{
stm << "graph {\n";
stm << " size=\"" << formatting.figureWidthInches << "," <<
formatting.figureHeightInches << "\";\n\n";
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
const GraphvizFormatting& writer,
const KeyFormatter& keyFormatter) const {
writer.writePreamble(&os);
// Find bounds (imperative)
KeySet keys = this->keys();
// Local utility function to extract x and y coordinates
struct { boost::optional<Point2> operator()(
const Value& value, const GraphvizFormatting& graphvizFormatting)
{
Vector3 t;
if (const GenericValue<Pose2>* p = dynamic_cast<const GenericValue<Pose2>*>(&value)) {
t << p->value().x(), p->value().y(), 0;
} else if (const GenericValue<Point2>* p = dynamic_cast<const GenericValue<Point2>*>(&value)) {
t << p->value().x(), p->value().y(), 0;
} else if (const GenericValue<Pose3>* p = dynamic_cast<const GenericValue<Pose3>*>(&value)) {
t = p->value().translation();
} else if (const GenericValue<Point3>* p = dynamic_cast<const GenericValue<Point3>*>(&value)) {
t = p->value();
} else {
return boost::none;
}
double x, y;
switch (graphvizFormatting.paperHorizontalAxis) {
case GraphvizFormatting::X: x = t.x(); break;
case GraphvizFormatting::Y: x = t.y(); break;
case GraphvizFormatting::Z: x = t.z(); break;
case GraphvizFormatting::NEGX: x = -t.x(); break;
case GraphvizFormatting::NEGY: x = -t.y(); break;
case GraphvizFormatting::NEGZ: x = -t.z(); break;
default: throw std::runtime_error("Invalid enum value");
}
switch (graphvizFormatting.paperVerticalAxis) {
case GraphvizFormatting::X: y = t.x(); break;
case GraphvizFormatting::Y: y = t.y(); break;
case GraphvizFormatting::Z: y = t.z(); break;
case GraphvizFormatting::NEGX: y = -t.x(); break;
case GraphvizFormatting::NEGY: y = -t.y(); break;
case GraphvizFormatting::NEGZ: y = -t.z(); break;
default: throw std::runtime_error("Invalid enum value");
}
return Point2(x,y);
}} getXY;
// Find bounds
double minX = numeric_limits<double>::infinity(), maxX = -numeric_limits<double>::infinity();
double minY = numeric_limits<double>::infinity(), maxY = -numeric_limits<double>::infinity();
for (const Key& key : keys) {
if (values.exists(key)) {
boost::optional<Point2> xy = getXY(values.at(key), formatting);
if(xy) {
if(xy->x() < minX)
minX = xy->x();
if(xy->x() > maxX)
maxX = xy->x();
if(xy->y() < minY)
minY = xy->y();
if(xy->y() > maxY)
maxY = xy->y();
}
}
}
Vector2 min = writer.findBounds(values, keys);
// Create nodes for each variable in the graph
for(Key key: keys){
// Label the node with the label from the KeyFormatter
stm << " var" << key << "[label=\"" << keyFormatter(key) << "\"";
if(values.exists(key)) {
boost::optional<Point2> xy = getXY(values.at(key), formatting);
if(xy)
stm << ", pos=\"" << formatting.scale*(xy->x() - minX) << "," << formatting.scale*(xy->y() - minY) << "!\"";
}
stm << "];\n";
for (Key key : keys) {
auto position = writer.variablePos(values, min, key);
writer.DrawVariable(key, keyFormatter, position, &os);
}
stm << "\n";
os << "\n";
if (formatting.mergeSimilarFactors) {
if (writer.mergeSimilarFactors) {
// Remove duplicate factors
std::set<KeyVector > structure;
std::set<KeyVector> structure;
for (const sharedFactor& factor : factors_) {
if (factor) {
KeyVector factorKeys = factor->keys();
@ -184,86 +120,40 @@ void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values,
// Create factors and variable connections
size_t i = 0;
for(const KeyVector& factorKeys: structure){
// Make each factor a dot
stm << " factor" << i << "[label=\"\", shape=point";
{
map<size_t, Point2>::const_iterator pos = formatting.factorPositions.find(i);
if(pos != formatting.factorPositions.end())
stm << ", pos=\"" << formatting.scale*(pos->second.x() - minX) << ","
<< formatting.scale*(pos->second.y() - minY) << "!\"";
}
stm << "];\n";
// Make factor-variable connections
for(Key key: factorKeys) {
stm << " var" << key << "--" << "factor" << i << ";\n";
}
++ i;
for (const KeyVector& factorKeys : structure) {
writer.processFactor(i++, factorKeys, boost::none, &os);
}
} else {
// Create factors and variable connections
for(size_t i = 0; i < size(); ++i) {
for (size_t i = 0; i < size(); ++i) {
const NonlinearFactor::shared_ptr& factor = at(i);
// If null pointer, move on to the next
if (!factor) {
continue;
}
if (formatting.plotFactorPoints) {
const KeyVector& keys = factor->keys();
if (formatting.binaryEdges && keys.size() == 2) {
stm << " var" << keys[0] << "--"
<< "var" << keys[1] << ";\n";
} else {
// Make each factor a dot
stm << " factor" << i << "[label=\"\", shape=point";
{
map<size_t, Point2>::const_iterator pos =
formatting.factorPositions.find(i);
if (pos != formatting.factorPositions.end())
stm << ", pos=\"" << formatting.scale * (pos->second.x() - minX)
<< "," << formatting.scale * (pos->second.y() - minY)
<< "!\"";
}
stm << "];\n";
// Make factor-variable connections
if (formatting.connectKeysToFactor && factor) {
for (Key key : *factor) {
stm << " var" << key << "--"
<< "factor" << i << ";\n";
}
}
}
} else {
Key k;
bool firstTime = true;
for (Key key : *this->at(i)) {
if (firstTime) {
k = key;
firstTime = false;
continue;
}
stm << " var" << key << "--"
<< "var" << k << ";\n";
k = key;
}
if (factor) {
const KeyVector& factorKeys = factor->keys();
writer.processFactor(i, factorKeys, writer.factorPos(min, i), &os);
}
}
}
stm << "}\n";
os << "}\n";
std::flush(os);
}
/* ************************************************************************* */
std::string NonlinearFactorGraph::dot(
const Values& values, const GraphvizFormatting& writer,
const KeyFormatter& keyFormatter) const {
std::stringstream ss;
dot(ss, values, writer, keyFormatter);
return ss.str();
}
/* ************************************************************************* */
void NonlinearFactorGraph::saveGraph(
const std::string& file, const Values& values,
const GraphvizFormatting& graphvizFormatting,
const std::string& filename, const Values& values,
const GraphvizFormatting& writer,
const KeyFormatter& keyFormatter) const {
std::ofstream of(file);
saveGraph(of, values, graphvizFormatting, keyFormatter);
std::ofstream of(filename);
dot(of, values, writer, keyFormatter);
of.close();
}

View File

@ -23,6 +23,7 @@
#include <gtsam/geometry/Point2.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
#include <gtsam/nonlinear/GraphvizFormatting.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/nonlinear/PriorFactor.h>
@ -41,32 +42,6 @@ namespace gtsam {
template<typename T>
class ExpressionFactor;
/**
* Formatting options when saving in GraphViz format using
* NonlinearFactorGraph::saveGraph.
*/
struct GTSAM_EXPORT GraphvizFormatting {
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; ///< World axes to be assigned to paper axes
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal paper axis
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper axis
double figureWidthInches; ///< The figure width on paper in inches
double figureHeightInches; ///< The figure height on paper in inches
double scale; ///< Scale all positions to reduce / increase density
bool mergeSimilarFactors; ///< Merge multiple factors that have the same connectivity
bool plotFactorPoints; ///< Plots each factor as a dot between the variables
bool connectKeysToFactor; ///< Draw a line from each key within a factor to the dot of the factor
bool binaryEdges; ///< just use non-dotted edges for binary factors
std::map<size_t, Point2> factorPositions; ///< (optional for each factor) Manually specify factor "dot" positions.
/// Default constructor sets up robot coordinates. Paper horizontal is robot Y,
/// paper vertical is robot X. Default figure size of 5x5 in.
GraphvizFormatting() :
paperHorizontalAxis(Y), paperVerticalAxis(X),
figureWidthInches(5), figureHeightInches(5), scale(1),
mergeSimilarFactors(false), plotFactorPoints(true),
connectKeysToFactor(true), binaryEdges(true) {}
};
/**
* A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors,
* which derive from NonlinearFactor. The values structures are typically (in SAM) more general
@ -115,21 +90,6 @@ namespace gtsam {
/** Test equality */
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
/// Write the graph in GraphViz format for visualization
void saveGraph(std::ostream& stm, const Values& values = Values(),
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/**
* Write the graph in GraphViz format to file for visualization.
*
* This is a wrapper friendly version since wrapped languages don't have
* access to C++ streams.
*/
void saveGraph(const std::string& file, const Values& values = Values(),
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/** unnormalized error, \f$ 0.5 \sum_i (h_i(X_i)-z)^2/\sigma^2 \f$ in the most common case */
double error(const Values& values) const;
@ -246,7 +206,33 @@ namespace gtsam {
emplace_shared<PriorFactor<T>>(key, prior, covariance);
}
private:
/// @name Graph Display
/// @{
using FactorGraph::dot;
using FactorGraph::saveGraph;
/// Output to graphviz format, stream version, with Values/extra options.
void dot(
std::ostream& os, const Values& values,
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string, with Values/extra options.
std::string dot(
const Values& values,
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format, with Values/extra options.
void saveGraph(
const std::string& filename, const Values& values,
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
private:
/**
* Linearize from Scatter rather than from Ordering. Made private because
@ -275,6 +261,14 @@ namespace gtsam {
Values GTSAM_DEPRECATED updateCholesky(const Values& values, boost::none_t,
const Dampen& dampen = nullptr) const
{return updateCholesky(values, dampen);}
/** \deprecated */
void GTSAM_DEPRECATED saveGraph(
std::ostream& os, const Values& values = Values(),
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
dot(os, values, graphvizFormatting, keyFormatter);
}
#endif
};

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -15,6 +15,7 @@
* @brief testNonlinearFactorGraph
* @author Carlos Nieto
* @author Christian Potthast
* @author Frank Dellaert
*/
#include <gtsam/base/Testable.h>
@ -285,6 +286,7 @@ TEST(testNonlinearFactorGraph, addPrior) {
EXPECT(0 != graph.error(values));
}
/* ************************************************************************* */
TEST(NonlinearFactorGraph, printErrors)
{
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
@ -309,6 +311,53 @@ TEST(NonlinearFactorGraph, printErrors)
for (bool visit : visited) EXPECT(visit==true);
}
/* ************************************************************************* */
TEST(NonlinearFactorGraph, dot) {
string expected =
"graph {\n"
" size=\"5,5\";\n"
"\n"
" var7782220156096217089[label=\"l1\"];\n"
" var8646911284551352321[label=\"x1\"];\n"
" var8646911284551352322[label=\"x2\"];\n"
"\n"
" factor0[label=\"\", shape=point];\n"
" var8646911284551352321--factor0;\n"
" var8646911284551352321--var8646911284551352322;\n"
" var8646911284551352321--var7782220156096217089;\n"
" var8646911284551352322--var7782220156096217089;\n"
"}\n";
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
string actual = fg.dot();
EXPECT(actual == expected);
}
/* ************************************************************************* */
TEST(NonlinearFactorGraph, dot_extra) {
string expected =
"graph {\n"
" size=\"5,5\";\n"
"\n"
" var7782220156096217089[label=\"l1\", pos=\"0,0!\"];\n"
" var8646911284551352321[label=\"x1\", pos=\"1,0!\"];\n"
" var8646911284551352322[label=\"x2\", pos=\"1,1.5!\"];\n"
"\n"
" factor0[label=\"\", shape=point];\n"
" var8646911284551352321--factor0;\n"
" var8646911284551352321--var8646911284551352322;\n"
" var8646911284551352321--var7782220156096217089;\n"
" var8646911284551352322--var7782220156096217089;\n"
"}\n";
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
const Values c = createValues();
stringstream ss;
fg.dot(ss, c);
EXPECT(ss.str() == expected);
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
/* ************************************************************************* */