Added position hints
parent
7ccee875fe
commit
87eeb0d27e
|
@ -16,9 +16,11 @@
|
||||||
* @date December, 2021
|
* @date December, 2021
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Vector.h>
|
|
||||||
#include <gtsam/inference/DotWriter.h>
|
#include <gtsam/inference/DotWriter.h>
|
||||||
|
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
#include <ostream>
|
#include <ostream>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
@ -59,18 +61,35 @@ void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ConnectVariables(Key key1, Key key2,
|
static void ConnectVariables(Key key1, Key key2,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter, ostream* os) {
|
||||||
ostream* os) {
|
|
||||||
*os << " var" << keyFormatter(key1) << "--"
|
*os << " var" << keyFormatter(key1) << "--"
|
||||||
<< "var" << keyFormatter(key2) << ";\n";
|
<< "var" << keyFormatter(key2) << ";\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter,
|
static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter,
|
||||||
size_t i, ostream* os) {
|
size_t i, ostream* os) {
|
||||||
*os << " var" << keyFormatter(key) << "--"
|
*os << " var" << keyFormatter(key) << "--"
|
||||||
<< "factor" << i << ";\n";
|
<< "factor" << i << ";\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return variable position or none
|
||||||
|
boost::optional<Vector2> DotWriter::variablePos(Key key) const {
|
||||||
|
boost::optional<Vector2> result = boost::none;
|
||||||
|
|
||||||
|
// Check position hint
|
||||||
|
Symbol symbol(key);
|
||||||
|
auto hint = positionHints.find(symbol.chr());
|
||||||
|
if (hint != positionHints.end())
|
||||||
|
result.reset(Vector2(symbol.index(), hint->second));
|
||||||
|
|
||||||
|
// Override with explicit position, if given.
|
||||||
|
auto pos = variablePositions.find(key);
|
||||||
|
if (pos != variablePositions.end())
|
||||||
|
result.reset(pos->second);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
void DotWriter::processFactor(size_t i, const KeyVector& keys,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
const boost::optional<Vector2>& position,
|
const boost::optional<Vector2>& position,
|
||||||
|
|
|
@ -39,9 +39,19 @@ struct GTSAM_EXPORT DotWriter {
|
||||||
///< the dot of the factor
|
///< the dot of the factor
|
||||||
bool binaryEdges; ///< just use non-dotted edges for binary factors
|
bool binaryEdges; ///< just use non-dotted edges for binary factors
|
||||||
|
|
||||||
/// (optional for each variable) Manually specify variable node positions
|
/**
|
||||||
|
* Variable positions can be optionally specified and will be included in the
|
||||||
|
* dor file with a "!' sign, so "neato" can use it to render them.
|
||||||
|
*/
|
||||||
std::map<gtsam::Key, Vector2> variablePositions;
|
std::map<gtsam::Key, Vector2> variablePositions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The position hints allow one to use symbol character and index to specify
|
||||||
|
* position. Unless variable positions are specified, if a hint is present for
|
||||||
|
* a given symbol, it will be used to calculate the positions as (index,hint).
|
||||||
|
*/
|
||||||
|
std::map<char, double> positionHints;
|
||||||
|
|
||||||
explicit DotWriter(double figureWidthInches = 5,
|
explicit DotWriter(double figureWidthInches = 5,
|
||||||
double figureHeightInches = 5,
|
double figureHeightInches = 5,
|
||||||
bool plotFactorPoints = true,
|
bool plotFactorPoints = true,
|
||||||
|
@ -68,13 +78,7 @@ struct GTSAM_EXPORT DotWriter {
|
||||||
std::ostream* os);
|
std::ostream* os);
|
||||||
|
|
||||||
/// Return variable position or none
|
/// Return variable position or none
|
||||||
boost::optional<Vector2> variablePos(Key key) const {
|
boost::optional<Vector2> variablePos(Key key) const;
|
||||||
auto it = variablePositions.find(key);
|
|
||||||
if (it == variablePositions.end())
|
|
||||||
return boost::none;
|
|
||||||
else
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Draw a single factor, specified by its index i and its variable keys.
|
/// Draw a single factor, specified by its index i and its variable keys.
|
||||||
void processFactor(size_t i, const KeyVector& keys,
|
void processFactor(size_t i, const KeyVector& keys,
|
||||||
|
|
|
@ -129,6 +129,7 @@ class DotWriter {
|
||||||
bool binaryEdges;
|
bool binaryEdges;
|
||||||
|
|
||||||
std::map<gtsam::Key, gtsam::Vector2> variablePositions;
|
std::map<gtsam::Key, gtsam::Vector2> variablePositions;
|
||||||
|
std::map<char, double> positionHints;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/inference/VariableIndex.h>
|
#include <gtsam/inference/VariableIndex.h>
|
||||||
|
|
|
@ -312,7 +312,6 @@ TEST(GaussianBayesNet, Dot) {
|
||||||
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
|
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
|
||||||
|
|
||||||
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
|
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
|
||||||
noisyBayesNet.saveGraph("noisyBayesNet.dot", DefaultKeyFormatter, writer);
|
|
||||||
EXPECT(actual ==
|
EXPECT(actual ==
|
||||||
"digraph {\n"
|
"digraph {\n"
|
||||||
" size=\"5,5\";\n"
|
" size=\"5,5\";\n"
|
||||||
|
|
|
@ -43,12 +43,14 @@ namespace gtsam {
|
||||||
class ExpressionFactor;
|
class ExpressionFactor;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors,
|
* A NonlinearFactorGraph is a graph of non-Gaussian, i.e. non-linear factors,
|
||||||
* which derive from NonlinearFactor. The values structures are typically (in SAM) more general
|
* which derive from NonlinearFactor. The values structures are typically (in
|
||||||
* than just vectors, e.g., Rot3 or Pose3, which are objects in non-linear manifolds.
|
* SAM) more general than just vectors, e.g., Rot3 or Pose3, which are objects
|
||||||
* Linearizing the non-linear factor graph creates a linear factor graph on the
|
* in non-linear manifolds. Linearizing the non-linear factor graph creates a
|
||||||
* tangent vector space at the linearization point. Because the tangent space is a true
|
* linear factor graph on the tangent vector space at the linearization point.
|
||||||
* vector space, the config type will be an VectorValues in that linearized factor graph.
|
* Because the tangent space is a true vector space, the config type will be
|
||||||
|
* an VectorValues in that linearized factor graph.
|
||||||
|
* @addtogroup nonlinear
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> {
|
class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> {
|
||||||
|
|
||||||
|
@ -58,6 +60,9 @@ namespace gtsam {
|
||||||
typedef NonlinearFactorGraph This;
|
typedef NonlinearFactorGraph This;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** Default constructor */
|
/** Default constructor */
|
||||||
NonlinearFactorGraph() {}
|
NonlinearFactorGraph() {}
|
||||||
|
|
||||||
|
@ -76,6 +81,10 @@ namespace gtsam {
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~NonlinearFactorGraph() {}
|
virtual ~NonlinearFactorGraph() {}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** print */
|
/** print */
|
||||||
void print(
|
void print(
|
||||||
const std::string& str = "NonlinearFactorGraph: ",
|
const std::string& str = "NonlinearFactorGraph: ",
|
||||||
|
@ -90,6 +99,10 @@ namespace gtsam {
|
||||||
/** Test equality */
|
/** Test equality */
|
||||||
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
|
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
|
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
|
||||||
double error(const Values& values) const;
|
double error(const Values& values) const;
|
||||||
|
|
||||||
|
@ -206,6 +219,7 @@ namespace gtsam {
|
||||||
emplace_shared<PriorFactor<T>>(key, prior, covariance);
|
emplace_shared<PriorFactor<T>>(key, prior, covariance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
/// @name Graph Display
|
/// @name Graph Display
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
@ -250,6 +264,8 @@ namespace gtsam {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated
|
||||||
|
/// @{
|
||||||
/** @deprecated */
|
/** @deprecated */
|
||||||
boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor(
|
boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor(
|
||||||
const Values& values, boost::none_t, const Dampen& dampen = nullptr) const
|
const Values& values, boost::none_t, const Dampen& dampen = nullptr) const
|
||||||
|
@ -261,19 +277,20 @@ namespace gtsam {
|
||||||
{return updateCholesky(values, dampen);}
|
{return updateCholesky(values, dampen);}
|
||||||
|
|
||||||
/** @deprecated */
|
/** @deprecated */
|
||||||
void GTSAM_DEPRECATED
|
void GTSAM_DEPRECATED saveGraph(
|
||||||
saveGraph(std::ostream& os, const Values& values = Values(),
|
std::ostream& os, const Values& values = Values(),
|
||||||
const GraphvizFormatting& writer = GraphvizFormatting(),
|
const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(),
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||||
dot(os, values, keyFormatter, graphvizFormatting);
|
dot(os, values, keyFormatter, graphvizFormatting);
|
||||||
}
|
}
|
||||||
/** @deprecated */
|
/** @deprecated */
|
||||||
void GTSAM_DEPRECATED
|
void GTSAM_DEPRECATED
|
||||||
saveGraph(const std::string& filename, const Values& values,
|
saveGraph(const std::string& filename, const Values& values,
|
||||||
const GraphvizFormatting& writer,
|
const GraphvizFormatting& graphvizFormatting,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||||
saveGraph(filename, values, keyFormatter, graphvizFormatting);
|
saveGraph(filename, values, keyFormatter, graphvizFormatting);
|
||||||
}
|
}
|
||||||
|
/// @}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
|
@ -15,13 +15,16 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/base/VectorSpace.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/symbolic/SymbolicConditional.h>
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <boost/make_shared.hpp>
|
||||||
#include <gtsam/symbolic/SymbolicBayesNet.h>
|
|
||||||
#include <gtsam/symbolic/SymbolicConditional.h>
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
@ -30,7 +33,6 @@ static const Key _L_ = 0;
|
||||||
static const Key _A_ = 1;
|
static const Key _A_ = 1;
|
||||||
static const Key _B_ = 2;
|
static const Key _B_ = 2;
|
||||||
static const Key _C_ = 3;
|
static const Key _C_ = 3;
|
||||||
static const Key _D_ = 4;
|
|
||||||
|
|
||||||
static SymbolicConditional::shared_ptr
|
static SymbolicConditional::shared_ptr
|
||||||
B(new SymbolicConditional(_B_)),
|
B(new SymbolicConditional(_B_)),
|
||||||
|
@ -78,14 +80,38 @@ TEST( SymbolicBayesNet, combine )
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(SymbolicBayesNet, saveGraph) {
|
TEST(SymbolicBayesNet, Dot) {
|
||||||
|
using symbol_shorthand::A;
|
||||||
|
using symbol_shorthand::X;
|
||||||
SymbolicBayesNet bn;
|
SymbolicBayesNet bn;
|
||||||
bn += SymbolicConditional(_A_, _B_);
|
bn += SymbolicConditional(X(3), X(2), A(2));
|
||||||
KeyVector keys {_B_, _C_, _D_};
|
bn += SymbolicConditional(X(2), X(1), A(1));
|
||||||
bn += SymbolicConditional::FromKeys(keys,2);
|
bn += SymbolicConditional(X(1));
|
||||||
bn += SymbolicConditional(_D_);
|
|
||||||
|
|
||||||
bn.saveGraph("SymbolicBayesNet.dot");
|
DotWriter writer;
|
||||||
|
writer.positionHints.emplace('a', 2);
|
||||||
|
writer.positionHints.emplace('x', 1);
|
||||||
|
|
||||||
|
auto position = writer.variablePos(A(1));
|
||||||
|
CHECK(position);
|
||||||
|
EXPECT(assert_equal(Vector2(1, 2), *position, 1e-5));
|
||||||
|
|
||||||
|
string actual = bn.dot(DefaultKeyFormatter, writer);
|
||||||
|
EXPECT(actual ==
|
||||||
|
"digraph {\n"
|
||||||
|
" size=\"5,5\";\n"
|
||||||
|
"\n"
|
||||||
|
" vara1[label=\"a1\", pos=\"1,2!\"];\n"
|
||||||
|
" vara2[label=\"a2\", pos=\"2,2!\"];\n"
|
||||||
|
" varx1[label=\"x1\", pos=\"1,1!\"];\n"
|
||||||
|
" varx2[label=\"x2\", pos=\"2,1!\"];\n"
|
||||||
|
" varx3[label=\"x3\", pos=\"3,1!\"];\n"
|
||||||
|
"\n"
|
||||||
|
" varx1->varx2\n"
|
||||||
|
" vara1->varx2\n"
|
||||||
|
" varx2->varx3\n"
|
||||||
|
" vara2->varx3\n"
|
||||||
|
"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue