Merge pull request #971 from borglab/feature/notebook_dot
commit
168a67da05
|
@ -173,7 +173,7 @@ TEST(Matrix, stack )
|
|||
{
|
||||
Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished();
|
||||
Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished();
|
||||
Matrix AB = stack(2, &A, &B);
|
||||
Matrix AB = gtsam::stack(2, &A, &B);
|
||||
Matrix C(5, 2);
|
||||
for (int i = 0; i < 2; i++)
|
||||
for (int j = 0; j < 2; j++)
|
||||
|
@ -187,7 +187,7 @@ TEST(Matrix, stack )
|
|||
std::vector<gtsam::Matrix> matrices;
|
||||
matrices.push_back(A);
|
||||
matrices.push_back(B);
|
||||
Matrix AB2 = stack(matrices);
|
||||
Matrix AB2 = gtsam::stack(matrices);
|
||||
EQUALITY(C,AB2);
|
||||
}
|
||||
|
||||
|
|
|
@ -248,8 +248,9 @@ namespace gtsam {
|
|||
void dot(std::ostream& os, bool showZero) const override {
|
||||
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
|
||||
<< "\"]\n";
|
||||
for (size_t i = 0; i < branches_.size(); i++) {
|
||||
NodePtr branch = branches_[i];
|
||||
size_t B = branches_.size();
|
||||
for (size_t i = 0; i < B; i++) {
|
||||
const NodePtr& branch = branches_[i];
|
||||
|
||||
// Check if zero
|
||||
if (!showZero) {
|
||||
|
@ -258,8 +259,10 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
||||
if (B == 2) {
|
||||
if (i == 0) os << " [style=dashed]";
|
||||
if (i > 1) os << " [style=bold]";
|
||||
}
|
||||
os << std::endl;
|
||||
branch->dot(os, showZero);
|
||||
}
|
||||
|
@ -673,6 +676,13 @@ namespace gtsam {
|
|||
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
|
||||
}
|
||||
|
||||
template<typename L, typename Y>
|
||||
std::string DecisionTree<L, Y>::dot(bool showZero) const {
|
||||
std::stringstream ss;
|
||||
dot(ss, showZero);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -198,6 +198,9 @@ namespace gtsam {
|
|||
/** output to graphviz format, open a file */
|
||||
void dot(const std::string& name, bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format string */
|
||||
std::string dot(bool showZero = true) const;
|
||||
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||
string dot(bool showZero = false) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
|
@ -52,8 +52,6 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||
const gtsam::DecisionTreeFactor& marginal,
|
||||
const gtsam::Ordering& orderedKeys);
|
||||
size_t size() const; // TODO(dellaert): why do I have to repeat???
|
||||
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
|
||||
void print(string s = "Discrete Conditional\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
@ -82,6 +80,8 @@ class DiscreteBayesNet {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
|
||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void saveGraph(string s,
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
@ -98,9 +98,19 @@ class DiscreteBayesTree {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
|
||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void saveGraph(string s,
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/inference/DotWriter.h>
|
||||
class DotWriter {
|
||||
DotWriter();
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
class DiscreteFactorGraph {
|
||||
DiscreteFactorGraph();
|
||||
|
@ -118,6 +128,14 @@ class DiscreteFactorGraph {
|
|||
void print(string s = "") const;
|
||||
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
|
||||
|
||||
string dot(const gtsam::DotWriter& dotWriter = gtsam::DotWriter(),
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void saveGraph(string s,
|
||||
const gtsam::DotWriter& dotWriter = gtsam::DotWriter(),
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
||||
gtsam::DecisionTreeFactor product() const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
|
|
|
@ -135,7 +135,7 @@ TEST(DiscreteBayesNet, Asia) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE(DiscreteBayesNet, Sugar) {
|
||||
TEST(DiscreteBayesNet, Sugar) {
|
||||
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);
|
||||
|
||||
DiscreteBayesNet bn;
|
||||
|
@ -149,6 +149,29 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) {
|
|||
bn.add(C | S = "1/1/2 5/2/3");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, Dot) {
|
||||
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
|
||||
Either(5, 2);
|
||||
|
||||
DiscreteBayesNet fragment;
|
||||
fragment.add(Asia % "99/1");
|
||||
fragment.add(Smoking % "50/50");
|
||||
|
||||
fragment.add(Tuberculosis | Asia = "99/1 95/5");
|
||||
fragment.add(LungCancer | Smoking = "99/1 90/10");
|
||||
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
|
||||
|
||||
string actual = fragment.dot();
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
"0->3\n"
|
||||
"4->6\n"
|
||||
"3->5\n"
|
||||
"6->5\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -26,76 +26,89 @@ using namespace boost::assign;
|
|||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
static bool debug = false;
|
||||
static constexpr bool debug = false;
|
||||
|
||||
/* ************************************************************************* */
|
||||
struct TestFixture {
|
||||
vector<DiscreteKey> keys;
|
||||
DiscreteBayesNet bayesNet;
|
||||
boost::shared_ptr<DiscreteBayesTree> bayesTree;
|
||||
|
||||
TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
||||
const int nrNodes = 15;
|
||||
const size_t nrStates = 2;
|
||||
|
||||
// define variables
|
||||
vector<DiscreteKey> key;
|
||||
for (int i = 0; i < nrNodes; i++) {
|
||||
DiscreteKey key_i(i, nrStates);
|
||||
key.push_back(key_i);
|
||||
/**
|
||||
* Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student),
|
||||
* and then create the Bayes tree from it.
|
||||
*/
|
||||
TestFixture() {
|
||||
// Define variables.
|
||||
for (int i = 0; i < 15; i++) {
|
||||
DiscreteKey key_i(i, 2);
|
||||
keys.push_back(key_i);
|
||||
}
|
||||
|
||||
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
||||
DiscreteBayesNet bayesNet;
|
||||
bayesNet.add(key[14] % "1/3");
|
||||
// Create thin-tree Bayesnet.
|
||||
bayesNet.add(keys[14] % "1/3");
|
||||
|
||||
bayesNet.add(key[13] | key[14] = "1/3 3/1");
|
||||
bayesNet.add(key[12] | key[14] = "3/1 3/1");
|
||||
bayesNet.add(keys[13] | keys[14] = "1/3 3/1");
|
||||
bayesNet.add(keys[12] | keys[14] = "3/1 3/1");
|
||||
|
||||
bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
|
||||
bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
|
||||
bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4");
|
||||
bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1");
|
||||
|
||||
bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
|
||||
bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1");
|
||||
bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1");
|
||||
|
||||
bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
|
||||
bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
|
||||
bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1");
|
||||
bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1");
|
||||
bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4");
|
||||
bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1");
|
||||
|
||||
// Create a BayesTree out of the Bayes net.
|
||||
bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
|
||||
}
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesTree, ThinTree) {
|
||||
const TestFixture self;
|
||||
const auto& keys = self.keys;
|
||||
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bayesNet);
|
||||
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||
GTSAM_PRINT(self.bayesNet);
|
||||
self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||
}
|
||||
|
||||
// create a BayesTree out of a Bayes net
|
||||
auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal();
|
||||
if (debug) {
|
||||
GTSAM_PRINT(*bayesTree);
|
||||
bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
||||
GTSAM_PRINT(*self.bayesTree);
|
||||
self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot");
|
||||
}
|
||||
|
||||
// Check frontals and parents
|
||||
for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) {
|
||||
auto clique_i = (*bayesTree)[i];
|
||||
auto clique_i = (*self.bayesTree)[i];
|
||||
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
||||
}
|
||||
|
||||
auto R = bayesTree->roots().front();
|
||||
auto R = self.bayesTree->roots().front();
|
||||
|
||||
// Check whether BN and BT give the same answer on all configurations
|
||||
vector<DiscreteValues> allPosbValues = cartesianProduct(
|
||||
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] &
|
||||
key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
||||
vector<DiscreteValues> allPosbValues =
|
||||
cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] &
|
||||
keys[5] & keys[6] & keys[7] & keys[8] & keys[9] &
|
||||
keys[10] & keys[11] & keys[12] & keys[13] & keys[14]);
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteValues x = allPosbValues[i];
|
||||
double expected = bayesNet.evaluate(x);
|
||||
double actual = bayesTree->evaluate(x);
|
||||
double expected = self.bayesNet.evaluate(x);
|
||||
double actual = self.bayesTree->evaluate(x);
|
||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||
}
|
||||
|
||||
|
@ -107,7 +120,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteValues x = allPosbValues[i];
|
||||
double px = bayesTree->evaluate(x);
|
||||
double px = self.bayesTree->evaluate(x);
|
||||
for (size_t i = 0; i < 15; i++)
|
||||
if (x[i]) marginals[i] += px;
|
||||
if (x[12] && x[14]) {
|
||||
|
@ -141,46 +154,46 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
DiscreteValues all1 = allPosbValues.back();
|
||||
|
||||
// check separator marginal P(S0)
|
||||
auto clique = (*bayesTree)[0];
|
||||
auto clique = (*self.bayesTree)[0];
|
||||
DiscreteFactorGraph separatorMarginal0 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||
|
||||
// check separator marginal P(S9), should be P(14)
|
||||
clique = (*bayesTree)[9];
|
||||
clique = (*self.bayesTree)[9];
|
||||
DiscreteFactorGraph separatorMarginal9 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||
|
||||
// check separator marginal of root, should be empty
|
||||
clique = (*bayesTree)[11];
|
||||
clique = (*self.bayesTree)[11];
|
||||
DiscreteFactorGraph separatorMarginal11 =
|
||||
clique->separatorMarginal(EliminateDiscrete);
|
||||
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||
|
||||
// check shortcut P(S9||R) to root
|
||||
clique = (*bayesTree)[9];
|
||||
clique = (*self.bayesTree)[9];
|
||||
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
LONGS_EQUAL(1, shortcut.size());
|
||||
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// check shortcut P(S8||R) to root
|
||||
clique = (*bayesTree)[8];
|
||||
clique = (*self.bayesTree)[8];
|
||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// check shortcut P(S2||R) to root
|
||||
clique = (*bayesTree)[2];
|
||||
clique = (*self.bayesTree)[2];
|
||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// check shortcut P(S0||R) to root
|
||||
clique = (*bayesTree)[0];
|
||||
clique = (*self.bayesTree)[0];
|
||||
shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||
|
||||
// calculate all shortcuts to root
|
||||
DiscreteBayesTree::Nodes cliques = bayesTree->nodes();
|
||||
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
|
||||
for (auto clique : cliques) {
|
||||
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
|
||||
if (debug) {
|
||||
|
@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
// Check all marginals
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
for (size_t i = 0; i < 15; i++) {
|
||||
marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||
double actual = (*marginalFactor)(all1);
|
||||
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||
}
|
||||
|
@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) {
|
|||
DiscreteBayesNet::shared_ptr actualJoint;
|
||||
|
||||
// Check joint P(8, 2)
|
||||
actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(1, 2)
|
||||
actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(2, 4)
|
||||
actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(4, 5)
|
||||
actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(4, 6)
|
||||
actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9);
|
||||
|
||||
// Check joint P(4, 11)
|
||||
actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
||||
actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete);
|
||||
DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesTree, Dot) {
|
||||
const TestFixture self;
|
||||
string actual = self.bayesTree->dot();
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
"0[label=\"13,11,6,7\"];\n"
|
||||
"0->1\n"
|
||||
"1[label=\"14 : 11,13\"];\n"
|
||||
"1->2\n"
|
||||
"2[label=\"9,12 : 14\"];\n"
|
||||
"2->3\n"
|
||||
"3[label=\"3 : 9,12\"];\n"
|
||||
"2->4\n"
|
||||
"4[label=\"2 : 9,12\"];\n"
|
||||
"2->5\n"
|
||||
"5[label=\"8 : 12,14\"];\n"
|
||||
"5->6\n"
|
||||
"6[label=\"1 : 8,12\"];\n"
|
||||
"5->7\n"
|
||||
"7[label=\"0 : 8,12\"];\n"
|
||||
"1->8\n"
|
||||
"8[label=\"10 : 13,14\"];\n"
|
||||
"8->9\n"
|
||||
"9[label=\"5 : 10,13\"];\n"
|
||||
"8->10\n"
|
||||
"10[label=\"4 : 10,13\"];\n"
|
||||
"}");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -359,6 +359,31 @@ cout << unicorns;
|
|||
}
|
||||
#endif
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteFactorGraph, Dot) {
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||
|
||||
// Create Factor graph
|
||||
DiscreteFactorGraph graph;
|
||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||
|
||||
string actual = graph.dot();
|
||||
string expected =
|
||||
"graph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" var0[label=\"0\"];\n"
|
||||
" var1[label=\"1\"];\n"
|
||||
" var2[label=\"2\"];\n"
|
||||
"\n"
|
||||
" var0--var1;\n"
|
||||
" var0--var2;\n"
|
||||
"}\n";
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -35,21 +35,39 @@ void BayesNet<CONDITIONAL>::print(
|
|||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
void BayesNet<CONDITIONAL>::saveGraph(const std::string& s,
|
||||
void BayesNet<CONDITIONAL>::dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::ofstream of(s.c_str());
|
||||
of << "digraph G{\n";
|
||||
os << "digraph G{\n";
|
||||
|
||||
for (auto conditional : boost::adaptors::reverse(*this)) {
|
||||
typename CONDITIONAL::Frontals frontals = conditional->frontals();
|
||||
Key me = frontals.front();
|
||||
typename CONDITIONAL::Parents parents = conditional->parents();
|
||||
for (Key p : parents)
|
||||
of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl;
|
||||
for (auto conditional : *this) {
|
||||
auto frontals = conditional->frontals();
|
||||
const Key me = frontals.front();
|
||||
auto parents = conditional->parents();
|
||||
for (const Key& p : parents)
|
||||
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n";
|
||||
}
|
||||
|
||||
of << "}";
|
||||
os << "}";
|
||||
std::flush(os);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const {
|
||||
std::stringstream ss;
|
||||
dot(ss, keyFormatter);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CONDITIONAL>
|
||||
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::ofstream of(filename.c_str());
|
||||
dot(of, keyFormatter);
|
||||
of.close();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -64,11 +64,21 @@ namespace gtsam {
|
|||
|
||||
/// @}
|
||||
|
||||
/// @name Standard Interface
|
||||
/// @name Graph Display
|
||||
/// @{
|
||||
|
||||
void saveGraph(const std::string& s,
|
||||
/// Output to graphviz format, stream version.
|
||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// Output to graphviz format string.
|
||||
std::string dot(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// output to file with graphviz format.
|
||||
void saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -64,19 +64,39 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
template <class CLIQUE>
|
||||
void BayesTree<CLIQUE>::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const {
|
||||
if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!");
|
||||
std::ofstream of(s.c_str());
|
||||
of<< "digraph G{\n";
|
||||
for(const sharedClique& root: roots_)
|
||||
saveGraph(of, root, keyFormatter);
|
||||
of<<"}";
|
||||
void BayesTree<CLIQUE>::dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
if (roots_.empty())
|
||||
throw std::invalid_argument(
|
||||
"the root of Bayes tree has not been initialized!");
|
||||
os << "digraph G{\n";
|
||||
for (const sharedClique& root : roots_) dot(os, root, keyFormatter);
|
||||
os << "}";
|
||||
std::flush(os);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CLIQUE>
|
||||
std::string BayesTree<CLIQUE>::dot(const KeyFormatter& keyFormatter) const {
|
||||
std::stringstream ss;
|
||||
dot(ss, keyFormatter);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CLIQUE>
|
||||
void BayesTree<CLIQUE>::saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::ofstream of(filename.c_str());
|
||||
dot(of, keyFormatter);
|
||||
of.close();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class CLIQUE>
|
||||
void BayesTree<CLIQUE>::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const {
|
||||
void BayesTree<CLIQUE>::dot(std::ostream& s, sharedClique clique,
|
||||
const KeyFormatter& indexFormatter,
|
||||
int parentnum) const {
|
||||
static int num = 0;
|
||||
bool first = true;
|
||||
std::stringstream out;
|
||||
|
@ -107,7 +127,7 @@ namespace gtsam {
|
|||
|
||||
for (sharedClique c : clique->children) {
|
||||
num++;
|
||||
saveGraph(s, c, indexFormatter, parentnum);
|
||||
dot(s, c, indexFormatter, parentnum);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -182,12 +182,19 @@ namespace gtsam {
|
|||
*/
|
||||
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;
|
||||
|
||||
/**
|
||||
* Read only with side effects
|
||||
*/
|
||||
/// @name Graph Display
|
||||
/// @{
|
||||
|
||||
/** saves the Tree to a text file in GraphViz format */
|
||||
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
/// Output to graphviz format, stream version.
|
||||
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// Output to graphviz format string.
|
||||
std::string dot(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// output to file with graphviz format.
|
||||
void saveGraph(const std::string& filename,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
|
@ -236,7 +243,7 @@ namespace gtsam {
|
|||
protected:
|
||||
|
||||
/** private helper method for saving the Tree to a text file in GraphViz format */
|
||||
void saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter,
|
||||
void dot(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter,
|
||||
int parentnum = 0) const;
|
||||
|
||||
/** Gather data on a single clique */
|
||||
|
@ -249,7 +256,7 @@ namespace gtsam {
|
|||
void fillNodesIndex(const sharedClique& subtree);
|
||||
|
||||
// Friend JunctionTree because it directly fills roots and nodes index.
|
||||
template<class BAYESRTEE, class GRAPH> friend class EliminatableClusterTree;
|
||||
template<class BAYESTREE, class GRAPH> friend class EliminatableClusterTree;
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DotWriter.cpp
|
||||
* @brief Graphviz formatting for factor graphs.
|
||||
* @author Frank Dellaert
|
||||
* @date December, 2021
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/inference/DotWriter.h>
|
||||
|
||||
#include <ostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
void DotWriter::writePreamble(ostream* os) const {
|
||||
*os << "graph {\n";
|
||||
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
|
||||
<< "\";\n\n";
|
||||
}
|
||||
|
||||
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||
const boost::optional<Vector2>& position,
|
||||
ostream* os) {
|
||||
// Label the node with the label from the KeyFormatter
|
||||
*os << " var" << key << "[label=\"" << keyFormatter(key) << "\"";
|
||||
if (position) {
|
||||
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
|
||||
}
|
||||
*os << "];\n";
|
||||
}
|
||||
|
||||
void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||
ostream* os) {
|
||||
*os << " factor" << i << "[label=\"\", shape=point";
|
||||
if (position) {
|
||||
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
|
||||
}
|
||||
*os << "];\n";
|
||||
}
|
||||
|
||||
void DotWriter::ConnectVariables(Key key1, Key key2, ostream* os) {
|
||||
*os << " var" << key1 << "--"
|
||||
<< "var" << key2 << ";\n";
|
||||
}
|
||||
|
||||
void DotWriter::ConnectVariableFactor(Key key, size_t i, ostream* os) {
|
||||
*os << " var" << key << "--"
|
||||
<< "factor" << i << ";\n";
|
||||
}
|
||||
|
||||
void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
||||
const boost::optional<Vector2>& position,
|
||||
ostream* os) const {
|
||||
if (plotFactorPoints) {
|
||||
if (binaryEdges && keys.size() == 2) {
|
||||
ConnectVariables(keys[0], keys[1], os);
|
||||
} else {
|
||||
// Create dot for the factor.
|
||||
DrawFactor(i, position, os);
|
||||
|
||||
// Make factor-variable connections
|
||||
if (connectKeysToFactor) {
|
||||
for (Key key : keys) {
|
||||
ConnectVariableFactor(key, i, os);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// just connect variables in a clique
|
||||
for (Key key1 : keys) {
|
||||
for (Key key2 : keys) {
|
||||
if (key2 > key1) {
|
||||
ConnectVariables(key1, key2, os);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,69 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DotWriter.h
|
||||
* @brief Graphviz formatter
|
||||
* @author Frank Dellaert
|
||||
* @date December, 2021
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/FastVector.h>
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
|
||||
#include <iosfwd>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/// Graphviz formatter.
|
||||
struct GTSAM_EXPORT DotWriter {
|
||||
double figureWidthInches; ///< The figure width on paper in inches
|
||||
double figureHeightInches; ///< The figure height on paper in inches
|
||||
bool plotFactorPoints; ///< Plots each factor as a dot between the variables
|
||||
bool connectKeysToFactor; ///< Draw a line from each key within a factor to
|
||||
///< the dot of the factor
|
||||
bool binaryEdges; ///< just use non-dotted edges for binary factors
|
||||
|
||||
DotWriter()
|
||||
: figureWidthInches(5),
|
||||
figureHeightInches(5),
|
||||
plotFactorPoints(true),
|
||||
connectKeysToFactor(true),
|
||||
binaryEdges(true) {}
|
||||
|
||||
/// Write out preamble, including size.
|
||||
void writePreamble(std::ostream* os) const;
|
||||
|
||||
/// Create a variable dot fragment.
|
||||
static void DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||
const boost::optional<Vector2>& position,
|
||||
std::ostream* os);
|
||||
|
||||
/// Create factor dot.
|
||||
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||
std::ostream* os);
|
||||
|
||||
/// Connect two variables.
|
||||
static void ConnectVariables(Key key1, Key key2, std::ostream* os);
|
||||
|
||||
/// Connect variable and factor.
|
||||
static void ConnectVariableFactor(Key key, size_t i, std::ostream* os);
|
||||
|
||||
/// Draw a single factor, specified by its index i and its variable keys.
|
||||
void processFactor(size_t i, const KeyVector& keys,
|
||||
const boost::optional<Vector2>& position,
|
||||
std::ostream* os) const;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
|
@ -26,6 +26,7 @@
|
|||
#include <stdio.h>
|
||||
#include <algorithm>
|
||||
#include <iostream> // for cout :-(
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
|
@ -125,4 +126,48 @@ FactorIndices FactorGraph<FACTOR>::add_factors(const CONTAINER& factors,
|
|||
return newFactorIndices;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FACTOR>
|
||||
void FactorGraph<FACTOR>::dot(std::ostream& os, const DotWriter& writer,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
writer.writePreamble(&os);
|
||||
|
||||
// Create nodes for each variable in the graph
|
||||
for (Key key : keys()) {
|
||||
writer.DrawVariable(key, keyFormatter, boost::none, &os);
|
||||
}
|
||||
os << "\n";
|
||||
|
||||
// Create factors and variable connections
|
||||
for (size_t i = 0; i < size(); ++i) {
|
||||
const auto& factor = at(i);
|
||||
if (factor) {
|
||||
const KeyVector& factorKeys = factor->keys();
|
||||
writer.processFactor(i, factorKeys, boost::none, &os);
|
||||
}
|
||||
}
|
||||
|
||||
os << "}\n";
|
||||
std::flush(os);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FACTOR>
|
||||
std::string FactorGraph<FACTOR>::dot(const DotWriter& writer,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::stringstream ss;
|
||||
dot(ss, writer, keyFormatter);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <class FACTOR>
|
||||
void FactorGraph<FACTOR>::saveGraph(const std::string& filename,
|
||||
const DotWriter& writer,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::ofstream of(filename.c_str());
|
||||
dot(of, writer, keyFormatter);
|
||||
of.close();
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -22,9 +22,10 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/DotWriter.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
#include <gtsam/base/FastVector.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
|
||||
#include <Eigen/Core> // for Eigen::aligned_allocator
|
||||
|
||||
|
@ -36,6 +37,7 @@
|
|||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <iosfwd>
|
||||
|
||||
namespace gtsam {
|
||||
/// Define collection type:
|
||||
|
@ -371,6 +373,23 @@ class FactorGraph {
|
|||
return factors_.erase(first, last);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Graph Display
|
||||
/// @{
|
||||
|
||||
/// Output to graphviz format, stream version.
|
||||
void dot(std::ostream& os, const DotWriter& writer = DotWriter(),
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// Output to graphviz format string.
|
||||
std::string dot(const DotWriter& writer = DotWriter(),
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// output to file with graphviz format.
|
||||
void saveGraph(const std::string& filename,
|
||||
const DotWriter& writer = DotWriter(),
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
|
|
@ -379,7 +379,7 @@ namespace gtsam {
|
|||
|
||||
gttic(Compute_minimizing_step_size);
|
||||
// Compute minimizing step size
|
||||
double step = -gradientSqNorm / dot(Rg, Rg);
|
||||
double step = -gradientSqNorm / gtsam::dot(Rg, Rg);
|
||||
gttoc(Compute_minimizing_step_size);
|
||||
|
||||
gttic(Compute_point);
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file GraphvizFormatting.cpp
|
||||
* @brief Graphviz formatter for NonlinearFactorGraph
|
||||
* @author Frank Dellaert
|
||||
* @date December, 2021
|
||||
*/
|
||||
|
||||
#include <gtsam/nonlinear/GraphvizFormatting.h>
|
||||
#include <gtsam/nonlinear/Values.h>
|
||||
|
||||
// TODO(frank): nonlinear should not depend on geometry:
|
||||
#include <gtsam/geometry/Pose2.h>
|
||||
#include <gtsam/geometry/Pose3.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
Vector2 GraphvizFormatting::findBounds(const Values& values,
|
||||
const KeySet& keys) const {
|
||||
Vector2 min;
|
||||
min.x() = std::numeric_limits<double>::infinity();
|
||||
min.y() = std::numeric_limits<double>::infinity();
|
||||
for (const Key& key : keys) {
|
||||
if (values.exists(key)) {
|
||||
boost::optional<Vector2> xy = operator()(values.at(key));
|
||||
if (xy) {
|
||||
if (xy->x() < min.x()) min.x() = xy->x();
|
||||
if (xy->y() < min.y()) min.y() = xy->y();
|
||||
}
|
||||
}
|
||||
}
|
||||
return min;
|
||||
}
|
||||
|
||||
boost::optional<Vector2> GraphvizFormatting::operator()(
|
||||
const Value& value) const {
|
||||
Vector3 t;
|
||||
if (const GenericValue<Pose2>* p =
|
||||
dynamic_cast<const GenericValue<Pose2>*>(&value)) {
|
||||
t << p->value().x(), p->value().y(), 0;
|
||||
} else if (const GenericValue<Vector2>* p =
|
||||
dynamic_cast<const GenericValue<Vector2>*>(&value)) {
|
||||
t << p->value().x(), p->value().y(), 0;
|
||||
} else if (const GenericValue<Pose3>* p =
|
||||
dynamic_cast<const GenericValue<Pose3>*>(&value)) {
|
||||
t = p->value().translation();
|
||||
} else if (const GenericValue<Point3>* p =
|
||||
dynamic_cast<const GenericValue<Point3>*>(&value)) {
|
||||
t = p->value();
|
||||
} else {
|
||||
return boost::none;
|
||||
}
|
||||
double x, y;
|
||||
switch (paperHorizontalAxis) {
|
||||
case X:
|
||||
x = t.x();
|
||||
break;
|
||||
case Y:
|
||||
x = t.y();
|
||||
break;
|
||||
case Z:
|
||||
x = t.z();
|
||||
break;
|
||||
case NEGX:
|
||||
x = -t.x();
|
||||
break;
|
||||
case NEGY:
|
||||
x = -t.y();
|
||||
break;
|
||||
case NEGZ:
|
||||
x = -t.z();
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Invalid enum value");
|
||||
}
|
||||
switch (paperVerticalAxis) {
|
||||
case X:
|
||||
y = t.x();
|
||||
break;
|
||||
case Y:
|
||||
y = t.y();
|
||||
break;
|
||||
case Z:
|
||||
y = t.z();
|
||||
break;
|
||||
case NEGX:
|
||||
y = -t.x();
|
||||
break;
|
||||
case NEGY:
|
||||
y = -t.y();
|
||||
break;
|
||||
case NEGZ:
|
||||
y = -t.z();
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Invalid enum value");
|
||||
}
|
||||
return Vector2(x, y);
|
||||
}
|
||||
|
||||
// Return affinely transformed variable position if it exists.
|
||||
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
|
||||
const Vector2& min,
|
||||
Key key) const {
|
||||
if (!values.exists(key)) return boost::none;
|
||||
boost::optional<Vector2> xy = operator()(values.at(key));
|
||||
if (xy) {
|
||||
xy->x() = scale * (xy->x() - min.x());
|
||||
xy->y() = scale * (xy->y() - min.y());
|
||||
}
|
||||
return xy;
|
||||
}
|
||||
|
||||
// Return affinely transformed factor position if it exists.
|
||||
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
|
||||
size_t i) const {
|
||||
if (factorPositions.size() == 0) return boost::none;
|
||||
auto it = factorPositions.find(i);
|
||||
if (it == factorPositions.end()) return boost::none;
|
||||
auto pos = it->second;
|
||||
return Vector2(scale * (pos.x() - min.x()), scale * (pos.y() - min.y()));
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,69 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file GraphvizFormatting.h
|
||||
* @brief Graphviz formatter for NonlinearFactorGraph
|
||||
* @author Frank Dellaert
|
||||
* @date December, 2021
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/DotWriter.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class Values;
|
||||
class Value;
|
||||
|
||||
/**
|
||||
* Formatting options and functions for saving a NonlinearFactorGraph instance
|
||||
* in GraphViz format.
|
||||
*/
|
||||
struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
|
||||
/// World axes to be assigned to paper axes
|
||||
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
|
||||
|
||||
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
|
||||
///< paper axis
|
||||
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
|
||||
///< axis
|
||||
double scale; ///< Scale all positions to reduce / increase density
|
||||
bool mergeSimilarFactors; ///< Merge multiple factors that have the same
|
||||
///< connectivity
|
||||
|
||||
/// (optional for each factor) Manually specify factor "dot" positions:
|
||||
std::map<size_t, Vector2> factorPositions;
|
||||
|
||||
/// Default constructor sets up robot coordinates. Paper horizontal is robot
|
||||
/// Y, paper vertical is robot X. Default figure size of 5x5 in.
|
||||
GraphvizFormatting()
|
||||
: paperHorizontalAxis(Y),
|
||||
paperVerticalAxis(X),
|
||||
scale(1),
|
||||
mergeSimilarFactors(false) {}
|
||||
|
||||
// Find bounds
|
||||
Vector2 findBounds(const Values& values, const KeySet& keys) const;
|
||||
|
||||
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
|
||||
boost::optional<Vector2> operator()(const Value& value) const;
|
||||
|
||||
/// Return affinely transformed variable position if it exists.
|
||||
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,
|
||||
Key key) const;
|
||||
|
||||
/// Return affinely transformed factor position if it exists.
|
||||
boost::optional<Vector2> factorPos(const Vector2& min, size_t i) const;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
|
@ -26,6 +26,7 @@
|
|||
#include <gtsam/linear/linearExceptions.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/inference/DotWriter.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
#include <gtsam/config.h> // for GTSAM_USE_TBB
|
||||
|
||||
|
@ -35,7 +36,6 @@
|
|||
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -91,87 +91,23 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values,
|
||||
const GraphvizFormatting& formatting,
|
||||
const KeyFormatter& keyFormatter) const
|
||||
{
|
||||
stm << "graph {\n";
|
||||
stm << " size=\"" << formatting.figureWidthInches << "," <<
|
||||
formatting.figureHeightInches << "\";\n\n";
|
||||
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||
const GraphvizFormatting& writer,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
writer.writePreamble(&os);
|
||||
|
||||
// Find bounds (imperative)
|
||||
KeySet keys = this->keys();
|
||||
|
||||
// Local utility function to extract x and y coordinates
|
||||
struct { boost::optional<Point2> operator()(
|
||||
const Value& value, const GraphvizFormatting& graphvizFormatting)
|
||||
{
|
||||
Vector3 t;
|
||||
if (const GenericValue<Pose2>* p = dynamic_cast<const GenericValue<Pose2>*>(&value)) {
|
||||
t << p->value().x(), p->value().y(), 0;
|
||||
} else if (const GenericValue<Point2>* p = dynamic_cast<const GenericValue<Point2>*>(&value)) {
|
||||
t << p->value().x(), p->value().y(), 0;
|
||||
} else if (const GenericValue<Pose3>* p = dynamic_cast<const GenericValue<Pose3>*>(&value)) {
|
||||
t = p->value().translation();
|
||||
} else if (const GenericValue<Point3>* p = dynamic_cast<const GenericValue<Point3>*>(&value)) {
|
||||
t = p->value();
|
||||
} else {
|
||||
return boost::none;
|
||||
}
|
||||
double x, y;
|
||||
switch (graphvizFormatting.paperHorizontalAxis) {
|
||||
case GraphvizFormatting::X: x = t.x(); break;
|
||||
case GraphvizFormatting::Y: x = t.y(); break;
|
||||
case GraphvizFormatting::Z: x = t.z(); break;
|
||||
case GraphvizFormatting::NEGX: x = -t.x(); break;
|
||||
case GraphvizFormatting::NEGY: x = -t.y(); break;
|
||||
case GraphvizFormatting::NEGZ: x = -t.z(); break;
|
||||
default: throw std::runtime_error("Invalid enum value");
|
||||
}
|
||||
switch (graphvizFormatting.paperVerticalAxis) {
|
||||
case GraphvizFormatting::X: y = t.x(); break;
|
||||
case GraphvizFormatting::Y: y = t.y(); break;
|
||||
case GraphvizFormatting::Z: y = t.z(); break;
|
||||
case GraphvizFormatting::NEGX: y = -t.x(); break;
|
||||
case GraphvizFormatting::NEGY: y = -t.y(); break;
|
||||
case GraphvizFormatting::NEGZ: y = -t.z(); break;
|
||||
default: throw std::runtime_error("Invalid enum value");
|
||||
}
|
||||
return Point2(x,y);
|
||||
}} getXY;
|
||||
|
||||
// Find bounds
|
||||
double minX = numeric_limits<double>::infinity(), maxX = -numeric_limits<double>::infinity();
|
||||
double minY = numeric_limits<double>::infinity(), maxY = -numeric_limits<double>::infinity();
|
||||
for (const Key& key : keys) {
|
||||
if (values.exists(key)) {
|
||||
boost::optional<Point2> xy = getXY(values.at(key), formatting);
|
||||
if(xy) {
|
||||
if(xy->x() < minX)
|
||||
minX = xy->x();
|
||||
if(xy->x() > maxX)
|
||||
maxX = xy->x();
|
||||
if(xy->y() < minY)
|
||||
minY = xy->y();
|
||||
if(xy->y() > maxY)
|
||||
maxY = xy->y();
|
||||
}
|
||||
}
|
||||
}
|
||||
Vector2 min = writer.findBounds(values, keys);
|
||||
|
||||
// Create nodes for each variable in the graph
|
||||
for (Key key : keys) {
|
||||
// Label the node with the label from the KeyFormatter
|
||||
stm << " var" << key << "[label=\"" << keyFormatter(key) << "\"";
|
||||
if(values.exists(key)) {
|
||||
boost::optional<Point2> xy = getXY(values.at(key), formatting);
|
||||
if(xy)
|
||||
stm << ", pos=\"" << formatting.scale*(xy->x() - minX) << "," << formatting.scale*(xy->y() - minY) << "!\"";
|
||||
auto position = writer.variablePos(values, min, key);
|
||||
writer.DrawVariable(key, keyFormatter, position, &os);
|
||||
}
|
||||
stm << "];\n";
|
||||
}
|
||||
stm << "\n";
|
||||
os << "\n";
|
||||
|
||||
if (formatting.mergeSimilarFactors) {
|
||||
if (writer.mergeSimilarFactors) {
|
||||
// Remove duplicate factors
|
||||
std::set<KeyVector> structure;
|
||||
for (const sharedFactor& factor : factors_) {
|
||||
|
@ -185,85 +121,39 @@ void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values,
|
|||
// Create factors and variable connections
|
||||
size_t i = 0;
|
||||
for (const KeyVector& factorKeys : structure) {
|
||||
// Make each factor a dot
|
||||
stm << " factor" << i << "[label=\"\", shape=point";
|
||||
{
|
||||
map<size_t, Point2>::const_iterator pos = formatting.factorPositions.find(i);
|
||||
if(pos != formatting.factorPositions.end())
|
||||
stm << ", pos=\"" << formatting.scale*(pos->second.x() - minX) << ","
|
||||
<< formatting.scale*(pos->second.y() - minY) << "!\"";
|
||||
}
|
||||
stm << "];\n";
|
||||
|
||||
// Make factor-variable connections
|
||||
for(Key key: factorKeys) {
|
||||
stm << " var" << key << "--" << "factor" << i << ";\n";
|
||||
}
|
||||
|
||||
++ i;
|
||||
writer.processFactor(i++, factorKeys, boost::none, &os);
|
||||
}
|
||||
} else {
|
||||
// Create factors and variable connections
|
||||
for (size_t i = 0; i < size(); ++i) {
|
||||
const NonlinearFactor::shared_ptr& factor = at(i);
|
||||
// If null pointer, move on to the next
|
||||
if (!factor) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (formatting.plotFactorPoints) {
|
||||
const KeyVector& keys = factor->keys();
|
||||
if (formatting.binaryEdges && keys.size() == 2) {
|
||||
stm << " var" << keys[0] << "--"
|
||||
<< "var" << keys[1] << ";\n";
|
||||
} else {
|
||||
// Make each factor a dot
|
||||
stm << " factor" << i << "[label=\"\", shape=point";
|
||||
{
|
||||
map<size_t, Point2>::const_iterator pos =
|
||||
formatting.factorPositions.find(i);
|
||||
if (pos != formatting.factorPositions.end())
|
||||
stm << ", pos=\"" << formatting.scale * (pos->second.x() - minX)
|
||||
<< "," << formatting.scale * (pos->second.y() - minY)
|
||||
<< "!\"";
|
||||
}
|
||||
stm << "];\n";
|
||||
|
||||
// Make factor-variable connections
|
||||
if (formatting.connectKeysToFactor && factor) {
|
||||
for (Key key : *factor) {
|
||||
stm << " var" << key << "--"
|
||||
<< "factor" << i << ";\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Key k;
|
||||
bool firstTime = true;
|
||||
for (Key key : *this->at(i)) {
|
||||
if (firstTime) {
|
||||
k = key;
|
||||
firstTime = false;
|
||||
continue;
|
||||
}
|
||||
stm << " var" << key << "--"
|
||||
<< "var" << k << ";\n";
|
||||
k = key;
|
||||
}
|
||||
if (factor) {
|
||||
const KeyVector& factorKeys = factor->keys();
|
||||
writer.processFactor(i, factorKeys, writer.factorPos(min, i), &os);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stm << "}\n";
|
||||
os << "}\n";
|
||||
std::flush(os);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::string NonlinearFactorGraph::dot(
|
||||
const Values& values, const GraphvizFormatting& writer,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::stringstream ss;
|
||||
dot(ss, values, writer, keyFormatter);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void NonlinearFactorGraph::saveGraph(
|
||||
const std::string& file, const Values& values,
|
||||
const GraphvizFormatting& graphvizFormatting,
|
||||
const std::string& filename, const Values& values,
|
||||
const GraphvizFormatting& writer,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
std::ofstream of(file);
|
||||
saveGraph(of, values, graphvizFormatting, keyFormatter);
|
||||
std::ofstream of(filename);
|
||||
dot(of, values, writer, keyFormatter);
|
||||
of.close();
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
#include <gtsam/geometry/Point2.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||
#include <gtsam/nonlinear/GraphvizFormatting.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/nonlinear/PriorFactor.h>
|
||||
|
||||
|
@ -41,32 +42,6 @@ namespace gtsam {
|
|||
template<typename T>
|
||||
class ExpressionFactor;
|
||||
|
||||
/**
|
||||
* Formatting options when saving in GraphViz format using
|
||||
* NonlinearFactorGraph::saveGraph.
|
||||
*/
|
||||
struct GTSAM_EXPORT GraphvizFormatting {
|
||||
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; ///< World axes to be assigned to paper axes
|
||||
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal paper axis
|
||||
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper axis
|
||||
double figureWidthInches; ///< The figure width on paper in inches
|
||||
double figureHeightInches; ///< The figure height on paper in inches
|
||||
double scale; ///< Scale all positions to reduce / increase density
|
||||
bool mergeSimilarFactors; ///< Merge multiple factors that have the same connectivity
|
||||
bool plotFactorPoints; ///< Plots each factor as a dot between the variables
|
||||
bool connectKeysToFactor; ///< Draw a line from each key within a factor to the dot of the factor
|
||||
bool binaryEdges; ///< just use non-dotted edges for binary factors
|
||||
std::map<size_t, Point2> factorPositions; ///< (optional for each factor) Manually specify factor "dot" positions.
|
||||
/// Default constructor sets up robot coordinates. Paper horizontal is robot Y,
|
||||
/// paper vertical is robot X. Default figure size of 5x5 in.
|
||||
GraphvizFormatting() :
|
||||
paperHorizontalAxis(Y), paperVerticalAxis(X),
|
||||
figureWidthInches(5), figureHeightInches(5), scale(1),
|
||||
mergeSimilarFactors(false), plotFactorPoints(true),
|
||||
connectKeysToFactor(true), binaryEdges(true) {}
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors,
|
||||
* which derive from NonlinearFactor. The values structures are typically (in SAM) more general
|
||||
|
@ -115,21 +90,6 @@ namespace gtsam {
|
|||
/** Test equality */
|
||||
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
|
||||
|
||||
/// Write the graph in GraphViz format for visualization
|
||||
void saveGraph(std::ostream& stm, const Values& values = Values(),
|
||||
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/**
|
||||
* Write the graph in GraphViz format to file for visualization.
|
||||
*
|
||||
* This is a wrapper friendly version since wrapped languages don't have
|
||||
* access to C++ streams.
|
||||
*/
|
||||
void saveGraph(const std::string& file, const Values& values = Values(),
|
||||
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/** unnormalized error, \f$ 0.5 \sum_i (h_i(X_i)-z)^2/\sigma^2 \f$ in the most common case */
|
||||
double error(const Values& values) const;
|
||||
|
||||
|
@ -246,6 +206,32 @@ namespace gtsam {
|
|||
emplace_shared<PriorFactor<T>>(key, prior, covariance);
|
||||
}
|
||||
|
||||
/// @name Graph Display
|
||||
/// @{
|
||||
|
||||
using FactorGraph::dot;
|
||||
using FactorGraph::saveGraph;
|
||||
|
||||
/// Output to graphviz format, stream version, with Values/extra options.
|
||||
void dot(
|
||||
std::ostream& os, const Values& values,
|
||||
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// Output to graphviz format string, with Values/extra options.
|
||||
std::string dot(
|
||||
const Values& values,
|
||||
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// output to file with graphviz format, with Values/extra options.
|
||||
void saveGraph(
|
||||
const std::string& filename, const Values& values,
|
||||
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
||||
/**
|
||||
|
@ -275,6 +261,14 @@ namespace gtsam {
|
|||
Values GTSAM_DEPRECATED updateCholesky(const Values& values, boost::none_t,
|
||||
const Dampen& dampen = nullptr) const
|
||||
{return updateCholesky(values, dampen);}
|
||||
|
||||
/** \deprecated */
|
||||
void GTSAM_DEPRECATED saveGraph(
|
||||
std::ostream& os, const Values& values = Values(),
|
||||
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||
dot(os, values, graphvizFormatting, keyFormatter);
|
||||
}
|
||||
#endif
|
||||
|
||||
};
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -15,6 +15,7 @@
|
|||
* @brief testNonlinearFactorGraph
|
||||
* @author Carlos Nieto
|
||||
* @author Christian Potthast
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
|
@ -285,6 +286,7 @@ TEST(testNonlinearFactorGraph, addPrior) {
|
|||
EXPECT(0 != graph.error(values));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(NonlinearFactorGraph, printErrors)
|
||||
{
|
||||
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
||||
|
@ -309,6 +311,53 @@ TEST(NonlinearFactorGraph, printErrors)
|
|||
for (bool visit : visited) EXPECT(visit==true);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(NonlinearFactorGraph, dot) {
|
||||
string expected =
|
||||
"graph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" var7782220156096217089[label=\"l1\"];\n"
|
||||
" var8646911284551352321[label=\"x1\"];\n"
|
||||
" var8646911284551352322[label=\"x2\"];\n"
|
||||
"\n"
|
||||
" factor0[label=\"\", shape=point];\n"
|
||||
" var8646911284551352321--factor0;\n"
|
||||
" var8646911284551352321--var8646911284551352322;\n"
|
||||
" var8646911284551352321--var7782220156096217089;\n"
|
||||
" var8646911284551352322--var7782220156096217089;\n"
|
||||
"}\n";
|
||||
|
||||
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
||||
string actual = fg.dot();
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(NonlinearFactorGraph, dot_extra) {
|
||||
string expected =
|
||||
"graph {\n"
|
||||
" size=\"5,5\";\n"
|
||||
"\n"
|
||||
" var7782220156096217089[label=\"l1\", pos=\"0,0!\"];\n"
|
||||
" var8646911284551352321[label=\"x1\", pos=\"1,0!\"];\n"
|
||||
" var8646911284551352322[label=\"x2\", pos=\"1,1.5!\"];\n"
|
||||
"\n"
|
||||
" factor0[label=\"\", shape=point];\n"
|
||||
" var8646911284551352321--factor0;\n"
|
||||
" var8646911284551352321--var8646911284551352322;\n"
|
||||
" var8646911284551352321--var7782220156096217089;\n"
|
||||
" var8646911284551352322--var7782220156096217089;\n"
|
||||
"}\n";
|
||||
|
||||
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
||||
const Values c = createValues();
|
||||
|
||||
stringstream ss;
|
||||
fg.dot(ss, c);
|
||||
EXPECT(ss.str() == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue