Merge pull request #1045 from borglab/feature/discrete_wrapping
commit
d8abdc280d
|
@ -143,67 +143,64 @@ void DiscreteConditional::print(const string& s,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << "):\n";
|
cout << "):\n";
|
||||||
ADT::print("");
|
ADT::print("", formatter);
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
double tol) const {
|
double tol) const {
|
||||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
|
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||||
return false;
|
return false;
|
||||||
else {
|
} else {
|
||||||
const DecisionTreeFactor& f(
|
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
||||||
static_cast<const DecisionTreeFactor&>(other));
|
|
||||||
return DecisionTreeFactor::equals(f, tol);
|
return DecisionTreeFactor::equals(f, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||||
const DiscreteValues& parentsValues) {
|
const DiscreteValues& given,
|
||||||
|
bool forceComplete = true) {
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
// branches based on the value of the parent variables.
|
// branches based on the value of the parent variables.
|
||||||
DiscreteConditional::ADT adt(conditional);
|
DiscreteConditional::ADT adt(conditional);
|
||||||
size_t value;
|
size_t value;
|
||||||
for (Key j : conditional.parents()) {
|
for (Key j : conditional.parents()) {
|
||||||
try {
|
try {
|
||||||
value = parentsValues.at(j);
|
value = given.at(j);
|
||||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
} catch (std::out_of_range&) {
|
} catch (std::out_of_range&) {
|
||||||
parentsValues.print("parentsValues: ");
|
if (forceComplete) {
|
||||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
given.print("parentsValues: ");
|
||||||
};
|
throw runtime_error(
|
||||||
|
"DiscreteConditional::Choose: parent value missing");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return adt;
|
return adt;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
|
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||||
const DiscreteValues& parentsValues) const {
|
const DiscreteValues& given) const {
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
ADT adt = Choose(*this, given, false); // P(F|S=given)
|
||||||
// branches based on the value of the parent variables.
|
|
||||||
ADT adt(*this);
|
|
||||||
size_t value;
|
|
||||||
for (Key j : parents()) {
|
|
||||||
try {
|
|
||||||
value = parentsValues.at(j);
|
|
||||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
|
||||||
} catch (exception&) {
|
|
||||||
parentsValues.print("parentsValues: ");
|
|
||||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert ADT to factor.
|
// Collect all keys not in given.
|
||||||
DiscreteKeys discreteKeys;
|
DiscreteKeys dKeys;
|
||||||
for (Key j : frontals()) {
|
for (Key j : frontals()) {
|
||||||
discreteKeys.emplace_back(j, this->cardinality(j));
|
dKeys.emplace_back(j, this->cardinality(j));
|
||||||
}
|
}
|
||||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
for (size_t i = nrFrontals(); i < size(); i++) {
|
||||||
|
Key j = keys_[i];
|
||||||
|
if (given.count(j) == 0) {
|
||||||
|
dKeys.emplace_back(j, this->cardinality(j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
const DiscreteValues& frontalValues) const {
|
const DiscreteValues& frontalValues) const {
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
|
@ -217,7 +214,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
} catch (exception&) {
|
} catch (exception&) {
|
||||||
frontalValues.print("frontalValues: ");
|
frontalValues.print("frontalValues: ");
|
||||||
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert ADT to factor.
|
// Convert ADT to factor.
|
||||||
|
@ -242,7 +239,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
// TODO(Abhijit): is this really the fastest way? He thinks it is.
|
|
||||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
|
@ -276,10 +272,8 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
(*values)[j] = sampled; // store result in partial solution
|
(*values)[j] = sampled; // store result in partial solution
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||||
|
|
||||||
// TODO: is this really the fastest way? I think it is.
|
|
||||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Then, find the max over all remaining
|
// Then, find the max over all remaining
|
||||||
|
|
|
@ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
/**
|
||||||
DecisionTreeFactor::shared_ptr choose(
|
* @brief restrict to given *parent* values.
|
||||||
const DiscreteValues& parentsValues) const;
|
*
|
||||||
|
* Note: does not need be complete set. Examples:
|
||||||
|
*
|
||||||
|
* P(C|D,E) + . -> P(C|D,E)
|
||||||
|
* P(C|D,E) + E -> P(C|D)
|
||||||
|
* P(C|D,E) + D -> P(C|E)
|
||||||
|
* P(C|D,E) + D,E -> P(C)
|
||||||
|
* P(C|D,E) + C -> error!
|
||||||
|
*
|
||||||
|
* @return a shared_ptr to a new DiscreteConditional
|
||||||
|
*/
|
||||||
|
shared_ptr choose(const DiscreteValues& given) const;
|
||||||
|
|
||||||
/** Convert to a likelihood factor by providing value before bar. */
|
/** Convert to a likelihood factor by providing value before bar. */
|
||||||
DecisionTreeFactor::shared_ptr likelihood(
|
DecisionTreeFactor::shared_ptr likelihood(
|
||||||
|
|
|
@ -64,32 +64,34 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
||||||
* Factor == DiscreteFactor
|
* Factor == DiscreteFactor
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
|
class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
|
: public FactorGraph<DiscreteFactor>,
|
||||||
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
||||||
public:
|
public:
|
||||||
|
using This = DiscreteFactorGraph; ///< this class
|
||||||
typedef DiscreteFactorGraph This; ///< Typedef to this class
|
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
|
||||||
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
|
using BaseEliminateable =
|
||||||
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
|
EliminateableFactorGraph<This>; ///< for elimination
|
||||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This
|
||||||
|
|
||||||
using Values = DiscreteValues; ///< backwards compatibility
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
|
|
||||||
/** A map from keys to values */
|
using Indices = KeyVector; ///> map from keys to values
|
||||||
typedef KeyVector Indices;
|
|
||||||
|
|
||||||
/** Default constructor */
|
/** Default constructor */
|
||||||
DiscreteFactorGraph() {}
|
DiscreteFactorGraph() {}
|
||||||
|
|
||||||
/** Construct from iterator over factors */
|
/** Construct from iterator over factors */
|
||||||
template <typename ITERATOR>
|
template <typename ITERATOR>
|
||||||
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
|
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
|
||||||
|
: Base(firstFactor, lastFactor) {}
|
||||||
|
|
||||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||||
template <class CONTAINER>
|
template <class CONTAINER>
|
||||||
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
|
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
/** Implicit copy/downcast constructor to override explicit template container
|
||||||
|
* constructor */
|
||||||
template <class DERIVEDFACTOR>
|
template <class DERIVEDFACTOR>
|
||||||
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
||||||
|
|
||||||
|
@ -166,6 +168,7 @@ public:
|
||||||
}; // \ DiscreteFactorGraph
|
}; // \ DiscreteFactorGraph
|
||||||
|
|
||||||
/// traits
|
/// traits
|
||||||
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
template <>
|
||||||
|
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||||
|
|
||||||
} // \ namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -107,8 +107,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
void printSignature(
|
void printSignature(
|
||||||
string s = "Discrete Conditional: ",
|
string s = "Discrete Conditional: ",
|
||||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
gtsam::DecisionTreeFactor* choose(
|
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
|
||||||
const gtsam::DiscreteValues& parentsValues) const;
|
|
||||||
gtsam::DecisionTreeFactor* likelihood(
|
gtsam::DecisionTreeFactor* likelihood(
|
||||||
const gtsam::DiscreteValues& frontalValues) const;
|
const gtsam::DiscreteValues& frontalValues) const;
|
||||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||||
|
@ -230,11 +229,16 @@ class DiscreteFactorGraph {
|
||||||
DiscreteFactorGraph();
|
DiscreteFactorGraph();
|
||||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
void add(const gtsam::DiscreteKey& j, string table);
|
// Building the graph
|
||||||
|
void push_back(const gtsam::DiscreteFactor* factor);
|
||||||
|
void push_back(const gtsam::DiscreteConditional* conditional);
|
||||||
|
void push_back(const gtsam::DiscreteFactorGraph& graph);
|
||||||
|
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
|
||||||
|
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
|
||||||
|
void add(const gtsam::DiscreteKey& j, string spec);
|
||||||
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
||||||
|
void add(const gtsam::DiscreteKeys& keys, string spec);
|
||||||
void add(const gtsam::DiscreteKeys& keys, string table);
|
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
||||||
void add(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
|
||||||
|
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
|
@ -258,8 +262,12 @@ class DiscreteFactorGraph {
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet eliminateSequential();
|
gtsam::DiscreteBayesNet eliminateSequential();
|
||||||
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
||||||
|
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
|
||||||
|
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
||||||
gtsam::DiscreteBayesTree eliminateMultifrontal();
|
gtsam::DiscreteBayesTree eliminateMultifrontal();
|
||||||
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
|
||||||
|
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
|
||||||
|
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
|
@ -221,6 +221,34 @@ TEST(DiscreteConditional, likelihood) {
|
||||||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check choose on P(C|D,E)
|
||||||
|
TEST(DiscreteConditional, choose) {
|
||||||
|
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||||
|
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||||
|
|
||||||
|
// Case 1: no given values: no-op
|
||||||
|
DiscreteValues given;
|
||||||
|
auto actual1 = C_given_DE.choose(given);
|
||||||
|
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
|
||||||
|
|
||||||
|
// Case 2: 1 given value
|
||||||
|
given[D.first] = 1;
|
||||||
|
auto actual2 = C_given_DE.choose(given);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
|
||||||
|
DiscreteConditional expected2(C | E = "1/1 1/4");
|
||||||
|
EXPECT(assert_equal(expected2, *actual2, 1e-9));
|
||||||
|
|
||||||
|
// Case 2: 2 given values
|
||||||
|
given[E.first] = 0;
|
||||||
|
auto actual3 = C_given_DE.choose(given);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
|
||||||
|
DiscreteConditional expected3(C % "1/1");
|
||||||
|
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check markdown representation looks as expected, no parents.
|
// Check markdown representation looks as expected, no parents.
|
||||||
TEST(DiscreteConditional, markdown_prior) {
|
TEST(DiscreteConditional, markdown_prior) {
|
||||||
|
|
|
@ -376,8 +376,12 @@ TEST(DiscreteFactorGraph, Dot) {
|
||||||
" var1[label=\"1\"];\n"
|
" var1[label=\"1\"];\n"
|
||||||
" var2[label=\"2\"];\n"
|
" var2[label=\"2\"];\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var0--var1;\n"
|
" factor0[label=\"\", shape=point];\n"
|
||||||
" var0--var2;\n"
|
" var0--factor0;\n"
|
||||||
|
" var1--factor0;\n"
|
||||||
|
" factor1[label=\"\", shape=point];\n"
|
||||||
|
" var0--factor1;\n"
|
||||||
|
" var2--factor1;\n"
|
||||||
"}\n";
|
"}\n";
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
@ -397,12 +401,16 @@ TEST(DiscreteFactorGraph, DotWithNames) {
|
||||||
"graph {\n"
|
"graph {\n"
|
||||||
" size=\"5,5\";\n"
|
" size=\"5,5\";\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var0[label=\"C\"];\n"
|
" varC[label=\"C\"];\n"
|
||||||
" var1[label=\"A\"];\n"
|
" varA[label=\"A\"];\n"
|
||||||
" var2[label=\"B\"];\n"
|
" varB[label=\"B\"];\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var0--var1;\n"
|
" factor0[label=\"\", shape=point];\n"
|
||||||
" var0--var2;\n"
|
" varC--factor0;\n"
|
||||||
|
" varA--factor0;\n"
|
||||||
|
" factor1[label=\"\", shape=point];\n"
|
||||||
|
" varC--factor1;\n"
|
||||||
|
" varB--factor1;\n"
|
||||||
"}\n";
|
"}\n";
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,8 @@ void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
ostream* os) {
|
ostream* os) {
|
||||||
// Label the node with the label from the KeyFormatter
|
// Label the node with the label from the KeyFormatter
|
||||||
*os << " var" << key << "[label=\"" << keyFormatter(key) << "\"";
|
*os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key)
|
||||||
|
<< "\"";
|
||||||
if (position) {
|
if (position) {
|
||||||
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
|
*os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
|
||||||
}
|
}
|
||||||
|
@ -51,22 +52,26 @@ void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||||
*os << "];\n";
|
*os << "];\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void DotWriter::ConnectVariables(Key key1, Key key2, ostream* os) {
|
static void ConnectVariables(Key key1, Key key2,
|
||||||
*os << " var" << key1 << "--"
|
const KeyFormatter& keyFormatter,
|
||||||
<< "var" << key2 << ";\n";
|
ostream* os) {
|
||||||
|
*os << " var" << keyFormatter(key1) << "--"
|
||||||
|
<< "var" << keyFormatter(key2) << ";\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void DotWriter::ConnectVariableFactor(Key key, size_t i, ostream* os) {
|
static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter,
|
||||||
*os << " var" << key << "--"
|
size_t i, ostream* os) {
|
||||||
|
*os << " var" << keyFormatter(key) << "--"
|
||||||
<< "factor" << i << ";\n";
|
<< "factor" << i << ";\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
ostream* os) const {
|
ostream* os) const {
|
||||||
if (plotFactorPoints) {
|
if (plotFactorPoints) {
|
||||||
if (binaryEdges && keys.size() == 2) {
|
if (binaryEdges && keys.size() == 2) {
|
||||||
ConnectVariables(keys[0], keys[1], os);
|
ConnectVariables(keys[0], keys[1], keyFormatter, os);
|
||||||
} else {
|
} else {
|
||||||
// Create dot for the factor.
|
// Create dot for the factor.
|
||||||
DrawFactor(i, position, os);
|
DrawFactor(i, position, os);
|
||||||
|
@ -74,7 +79,7 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
||||||
// Make factor-variable connections
|
// Make factor-variable connections
|
||||||
if (connectKeysToFactor) {
|
if (connectKeysToFactor) {
|
||||||
for (Key key : keys) {
|
for (Key key : keys) {
|
||||||
ConnectVariableFactor(key, i, os);
|
ConnectVariableFactor(key, keyFormatter, i, os);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,7 +88,7 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
||||||
for (Key key1 : keys) {
|
for (Key key1 : keys) {
|
||||||
for (Key key2 : keys) {
|
for (Key key2 : keys) {
|
||||||
if (key2 > key1) {
|
if (key2 > key1) {
|
||||||
ConnectVariables(key1, key2, os);
|
ConnectVariables(key1, key2, keyFormatter, os);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ struct GTSAM_EXPORT DotWriter {
|
||||||
explicit DotWriter(double figureWidthInches = 5,
|
explicit DotWriter(double figureWidthInches = 5,
|
||||||
double figureHeightInches = 5,
|
double figureHeightInches = 5,
|
||||||
bool plotFactorPoints = true,
|
bool plotFactorPoints = true,
|
||||||
bool connectKeysToFactor = true, bool binaryEdges = true)
|
bool connectKeysToFactor = true, bool binaryEdges = false)
|
||||||
: figureWidthInches(figureWidthInches),
|
: figureWidthInches(figureWidthInches),
|
||||||
figureHeightInches(figureHeightInches),
|
figureHeightInches(figureHeightInches),
|
||||||
plotFactorPoints(plotFactorPoints),
|
plotFactorPoints(plotFactorPoints),
|
||||||
|
@ -57,14 +57,9 @@ struct GTSAM_EXPORT DotWriter {
|
||||||
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||||
std::ostream* os);
|
std::ostream* os);
|
||||||
|
|
||||||
/// Connect two variables.
|
|
||||||
static void ConnectVariables(Key key1, Key key2, std::ostream* os);
|
|
||||||
|
|
||||||
/// Connect variable and factor.
|
|
||||||
static void ConnectVariableFactor(Key key, size_t i, std::ostream* os);
|
|
||||||
|
|
||||||
/// Draw a single factor, specified by its index i and its variable keys.
|
/// Draw a single factor, specified by its index i and its variable keys.
|
||||||
void processFactor(size_t i, const KeyVector& keys,
|
void processFactor(size_t i, const KeyVector& keys,
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
std::ostream* os) const;
|
std::ostream* os) const;
|
||||||
};
|
};
|
||||||
|
|
|
@ -144,7 +144,7 @@ void FactorGraph<FACTOR>::dot(std::ostream& os,
|
||||||
const auto& factor = at(i);
|
const auto& factor = at(i);
|
||||||
if (factor) {
|
if (factor) {
|
||||||
const KeyVector& factorKeys = factor->keys();
|
const KeyVector& factorKeys = factor->keys();
|
||||||
writer.processFactor(i, factorKeys, boost::none, &os);
|
writer.processFactor(i, factorKeys, keyFormatter, boost::none, &os);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,8 +33,10 @@
|
||||||
# include <tbb/parallel_for.h>
|
# include <tbb/parallel_for.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
@ -127,7 +129,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||||
// Create factors and variable connections
|
// Create factors and variable connections
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
for (const KeyVector& factorKeys : structure) {
|
for (const KeyVector& factorKeys : structure) {
|
||||||
writer.processFactor(i++, factorKeys, boost::none, &os);
|
writer.processFactor(i++, factorKeys, keyFormatter, boost::none, &os);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Create factors and variable connections
|
// Create factors and variable connections
|
||||||
|
@ -135,7 +137,8 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
|
||||||
const NonlinearFactor::shared_ptr& factor = at(i);
|
const NonlinearFactor::shared_ptr& factor = at(i);
|
||||||
if (factor) {
|
if (factor) {
|
||||||
const KeyVector& factorKeys = factor->keys();
|
const KeyVector& factorKeys = factor->keys();
|
||||||
writer.processFactor(i, factorKeys, writer.factorPos(min, i), &os);
|
writer.processFactor(i, factorKeys, keyFormatter,
|
||||||
|
writer.factorPos(min, i), &os);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -335,15 +335,21 @@ TEST(NonlinearFactorGraph, dot) {
|
||||||
"graph {\n"
|
"graph {\n"
|
||||||
" size=\"5,5\";\n"
|
" size=\"5,5\";\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var7782220156096217089[label=\"l1\"];\n"
|
" varl1[label=\"l1\"];\n"
|
||||||
" var8646911284551352321[label=\"x1\"];\n"
|
" varx1[label=\"x1\"];\n"
|
||||||
" var8646911284551352322[label=\"x2\"];\n"
|
" varx2[label=\"x2\"];\n"
|
||||||
"\n"
|
"\n"
|
||||||
" factor0[label=\"\", shape=point];\n"
|
" factor0[label=\"\", shape=point];\n"
|
||||||
" var8646911284551352321--factor0;\n"
|
" varx1--factor0;\n"
|
||||||
" var8646911284551352321--var8646911284551352322;\n"
|
" factor1[label=\"\", shape=point];\n"
|
||||||
" var8646911284551352321--var7782220156096217089;\n"
|
" varx1--factor1;\n"
|
||||||
" var8646911284551352322--var7782220156096217089;\n"
|
" varx2--factor1;\n"
|
||||||
|
" factor2[label=\"\", shape=point];\n"
|
||||||
|
" varx1--factor2;\n"
|
||||||
|
" varl1--factor2;\n"
|
||||||
|
" factor3[label=\"\", shape=point];\n"
|
||||||
|
" varx2--factor3;\n"
|
||||||
|
" varl1--factor3;\n"
|
||||||
"}\n";
|
"}\n";
|
||||||
|
|
||||||
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
||||||
|
@ -357,15 +363,21 @@ TEST(NonlinearFactorGraph, dot_extra) {
|
||||||
"graph {\n"
|
"graph {\n"
|
||||||
" size=\"5,5\";\n"
|
" size=\"5,5\";\n"
|
||||||
"\n"
|
"\n"
|
||||||
" var7782220156096217089[label=\"l1\", pos=\"0,0!\"];\n"
|
" varl1[label=\"l1\", pos=\"0,0!\"];\n"
|
||||||
" var8646911284551352321[label=\"x1\", pos=\"1,0!\"];\n"
|
" varx1[label=\"x1\", pos=\"1,0!\"];\n"
|
||||||
" var8646911284551352322[label=\"x2\", pos=\"1,1.5!\"];\n"
|
" varx2[label=\"x2\", pos=\"1,1.5!\"];\n"
|
||||||
"\n"
|
"\n"
|
||||||
" factor0[label=\"\", shape=point];\n"
|
" factor0[label=\"\", shape=point];\n"
|
||||||
" var8646911284551352321--factor0;\n"
|
" varx1--factor0;\n"
|
||||||
" var8646911284551352321--var8646911284551352322;\n"
|
" factor1[label=\"\", shape=point];\n"
|
||||||
" var8646911284551352321--var7782220156096217089;\n"
|
" varx1--factor1;\n"
|
||||||
" var8646911284551352322--var7782220156096217089;\n"
|
" varx2--factor1;\n"
|
||||||
|
" factor2[label=\"\", shape=point];\n"
|
||||||
|
" varx1--factor2;\n"
|
||||||
|
" varl1--factor2;\n"
|
||||||
|
" factor3[label=\"\", shape=point];\n"
|
||||||
|
" varx2--factor3;\n"
|
||||||
|
" varl1--factor3;\n"
|
||||||
"}\n";
|
"}\n";
|
||||||
|
|
||||||
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
const NonlinearFactorGraph fg = createNonlinearFactorGraph();
|
||||||
|
|
Loading…
Reference in New Issue