Gaussian Factor Graph unittests and linearization

(with Python bindings)
release/4.3a0
Gerry Chen 2019-10-09 15:58:57 -04:00
parent 19315cc3f3
commit e20494324f
5 changed files with 238 additions and 75 deletions

View File

@ -0,0 +1,91 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Linear Factor Graphs.
Author: Frank Dellaert & Gerry Chen
"""
# pylint: disable=invalid-name, no-name-in-module, no-member
from __future__ import print_function
import unittest
import gtsam
from gtsam.utils.test_case import GtsamTestCase
import numpy as np
def create_graph():
"""Create a basic linear factor graph for testing"""
graph = gtsam.GaussianFactorGraph()
x0 = gtsam.symbol(ord('x'), 0)
x1 = gtsam.symbol(ord('x'), 1)
x2 = gtsam.symbol(ord('x'), 2)
BETWEEN_NOISE = gtsam.noiseModel_Diagonal.Sigmas(np.ones(1))
PRIOR_NOISE = gtsam.noiseModel_Diagonal.Sigmas(np.ones(1))
graph.add(x1, np.eye(1), x0, -np.eye(1), np.ones(1), BETWEEN_NOISE)
graph.add(x2, np.eye(1), x1, -np.eye(1), 2*np.ones(1), BETWEEN_NOISE)
graph.add(x0, np.eye(1), np.zeros(1), PRIOR_NOISE)
return graph, (x0, x1, x2)
class TestGaussianFactorGraph(GtsamTestCase):
"""Tests for Gaussian Factor Graphs."""
def test_fg(self):
"""Test solving a linear factor graph"""
graph, X = create_graph()
result = graph.optimize()
EXPECTEDX = [0, 1, 3]
# check solutions
self.assertAlmostEqual(EXPECTEDX[0], result.at(X[0]), delta=1e-8)
self.assertAlmostEqual(EXPECTEDX[1], result.at(X[1]), delta=1e-8)
self.assertAlmostEqual(EXPECTEDX[2], result.at(X[2]), delta=1e-8)
def test_convertNonlinear(self):
"""Test converting a linear factor graph to a nonlinear one"""
graph, X = create_graph()
EXPECTEDM = [1, 2, 3]
# create nonlinear factor graph for marginalization
nfg = gtsam.LinearContainerFactor.ConvertLinearGraph(graph)
optimizer = gtsam.LevenbergMarquardtOptimizer(nfg, gtsam.Values())
nlresult = optimizer.optimizeSafely()
# marginalize
marginals = gtsam.Marginals(nfg, nlresult)
m = [marginals.marginalCovariance(x) for x in X]
# check linear marginalizations
self.assertAlmostEqual(EXPECTEDM[0], m[0], delta=1e-8)
self.assertAlmostEqual(EXPECTEDM[1], m[1], delta=1e-8)
self.assertAlmostEqual(EXPECTEDM[2], m[2], delta=1e-8)
def test_linearMarginalization(self):
"""Marginalize a linear factor graph"""
graph, X = create_graph()
result = graph.optimize()
EXPECTEDM = [1, 2, 3]
# linear factor graph marginalize
marginals = gtsam.Marginals(graph, result)
m = [marginals.marginalCovariance(x) for x in X]
# check linear marginalizations
self.assertAlmostEqual(EXPECTEDM[0], m[0], delta=1e-8)
self.assertAlmostEqual(EXPECTEDM[1], m[1], delta=1e-8)
self.assertAlmostEqual(EXPECTEDM[2], m[2], delta=1e-8)
if __name__ == '__main__':
unittest.main()

View File

@ -2011,6 +2011,10 @@ class Values {
class Marginals { class Marginals {
Marginals(const gtsam::NonlinearFactorGraph& graph, Marginals(const gtsam::NonlinearFactorGraph& graph,
const gtsam::Values& solution); const gtsam::Values& solution);
Marginals(const gtsam::GaussianFactorGraph& gfgraph,
const gtsam::Values& solution);
Marginals(const gtsam::GaussianFactorGraph& gfgraph,
const gtsam::VectorValues& solutionvec);
void print(string s) const; void print(string s) const;
Matrix marginalCovariance(size_t variable) const; Matrix marginalCovariance(size_t variable) const;

View File

@ -28,15 +28,35 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
Marginals::Marginals(const NonlinearFactorGraph& graph, const Values& solution, Factorization factorization, Marginals::Marginals(const NonlinearFactorGraph& graph, const Values& solution, Factorization factorization,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering) EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering)
{ : values_(solution), factorization_(factorization) {
gttic(MarginalsConstructor); gttic(MarginalsConstructor);
// Linearize graph
graph_ = *graph.linearize(solution); graph_ = *graph.linearize(solution);
computeBayesTree(ordering);
}
// Store values /* ************************************************************************* */
values_ = solution; Marginals::Marginals(const GaussianFactorGraph& graph, const VectorValues& solution, Factorization factorization,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering)
: graph_(graph), factorization_(factorization) {
gttic(MarginalsConstructor);
Values vals;
for (const VectorValues::KeyValuePair& v: solution) {
vals.insert(v.first, v.second);
}
values_ = vals;
computeBayesTree(ordering);
}
/* ************************************************************************* */
Marginals::Marginals(const GaussianFactorGraph& graph, const Values& solution, Factorization factorization,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering)
: graph_(graph), values_(solution), factorization_(factorization) {
gttic(MarginalsConstructor);
computeBayesTree(ordering);
}
/* ************************************************************************* */
void Marginals::computeBayesTree(EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering) {
// Compute BayesTree // Compute BayesTree
factorization_ = factorization; factorization_ = factorization;
if(factorization_ == CHOLESKY) if(factorization_ == CHOLESKY)

View File

@ -51,7 +51,7 @@ public:
/// Default constructor only for Cython wrapper /// Default constructor only for Cython wrapper
Marginals(){} Marginals(){}
/** Construct a marginals class. /** Construct a marginals class from a nonlinear factor graph.
* @param graph The factor graph defining the full joint density on all variables. * @param graph The factor graph defining the full joint density on all variables.
* @param solution The linearization point about which to compute Gaussian marginals (usually the MLE as obtained from a NonlinearOptimizer). * @param solution The linearization point about which to compute Gaussian marginals (usually the MLE as obtained from a NonlinearOptimizer).
* @param factorization The linear decomposition mode - either Marginals::CHOLESKY (faster and suitable for most problems) or Marginals::QR (slower but more numerically stable for poorly-conditioned problems). * @param factorization The linear decomposition mode - either Marginals::CHOLESKY (faster and suitable for most problems) or Marginals::QR (slower but more numerically stable for poorly-conditioned problems).
@ -60,6 +60,24 @@ public:
Marginals(const NonlinearFactorGraph& graph, const Values& solution, Factorization factorization = CHOLESKY, Marginals(const NonlinearFactorGraph& graph, const Values& solution, Factorization factorization = CHOLESKY,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering = boost::none); EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering = boost::none);
/** Construct a marginals class from a linear factor graph.
* @param graph The factor graph defining the full joint density on all variables.
* @param solution The solution point to compute Gaussian marginals.
* @param factorization The linear decomposition mode - either Marginals::CHOLESKY (faster and suitable for most problems) or Marginals::QR (slower but more numerically stable for poorly-conditioned problems).
* @param ordering An optional variable ordering for elimination.
*/
Marginals(const GaussianFactorGraph& graph, const Values& solution, Factorization factorization = CHOLESKY,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering = boost::none);
/** Construct a marginals class from a linear factor graph.
* @param graph The factor graph defining the full joint density on all variables.
* @param solution The solution point to compute Gaussian marginals.
* @param factorization The linear decomposition mode - either Marginals::CHOLESKY (faster and suitable for most problems) or Marginals::QR (slower but more numerically stable for poorly-conditioned problems).
* @param ordering An optional variable ordering for elimination.
*/
Marginals(const GaussianFactorGraph& graph, const VectorValues& solution, Factorization factorization = CHOLESKY,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering = boost::none);
/** print */ /** print */
void print(const std::string& str = "Marginals: ", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; void print(const std::string& str = "Marginals: ", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
@ -81,6 +99,12 @@ public:
/** Optimize the bayes tree */ /** Optimize the bayes tree */
VectorValues optimize() const; VectorValues optimize() const;
protected:
/** Compute the Bayes Tree as a helper function to the constructor */
void computeBayesTree(EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering);
}; };
/** /**

View File

@ -87,6 +87,12 @@ TEST(Marginals, planarSLAMmarginals) {
soln.insert(x3, Pose2(4.0, 0.0, 0.0)); soln.insert(x3, Pose2(4.0, 0.0, 0.0));
soln.insert(l1, Point2(2.0, 2.0)); soln.insert(l1, Point2(2.0, 2.0));
soln.insert(l2, Point2(4.0, 2.0)); soln.insert(l2, Point2(4.0, 2.0));
VectorValues soln_lin;
soln_lin.insert(x1, Vector3(0.0, 0.0, 0.0));
soln_lin.insert(x2, Vector3(2.0, 0.0, 0.0));
soln_lin.insert(x3, Vector3(4.0, 0.0, 0.0));
soln_lin.insert(l1, Vector2(2.0, 2.0));
soln_lin.insert(l2, Vector2(4.0, 2.0));
Matrix expectedx1(3,3); Matrix expectedx1(3,3);
expectedx1 << expectedx1 <<
@ -110,71 +116,80 @@ TEST(Marginals, planarSLAMmarginals) {
Matrix expectedl2(2,2); Matrix expectedl2(2,2);
expectedl2 << expectedl2 <<
0.293870968, -0.104516129, 0.293870968, -0.104516129,
-0.104516129, 0.391935484; -0.104516129, 0.391935484;
// Check marginals covariances for all variables (Cholesky mode) auto testMarginals = [&] (Marginals marginals) {
Marginals marginals(graph, soln, Marginals::CHOLESKY); EXPECT(assert_equal(expectedx1, marginals.marginalCovariance(x1), 1e-8));
EXPECT(assert_equal(expectedx1, marginals.marginalCovariance(x1), 1e-8)); EXPECT(assert_equal(expectedx2, marginals.marginalCovariance(x2), 1e-8));
EXPECT(assert_equal(expectedx2, marginals.marginalCovariance(x2), 1e-8)); EXPECT(assert_equal(expectedx3, marginals.marginalCovariance(x3), 1e-8));
EXPECT(assert_equal(expectedx3, marginals.marginalCovariance(x3), 1e-8)); EXPECT(assert_equal(expectedl1, marginals.marginalCovariance(l1), 1e-8));
EXPECT(assert_equal(expectedl1, marginals.marginalCovariance(l1), 1e-8)); EXPECT(assert_equal(expectedl2, marginals.marginalCovariance(l2), 1e-8));
EXPECT(assert_equal(expectedl2, marginals.marginalCovariance(l2), 1e-8)); };
// Check marginals covariances for all variables (QR mode) auto testJointMarginals = [&] (Marginals marginals) {
// Check joint marginals for 3 variables
Matrix expected_l2x1x3(8,8);
expected_l2x1x3 <<
0.293871159514111, -0.104516127560770, 0.090000180000270, -0.000000000000000, -0.020000000000000, 0.151935669757191, -0.104516127560770, -0.050967744878460,
-0.104516127560770, 0.391935664055174, 0.000000000000000, 0.090000180000270, 0.040000000000000, 0.007741936219615, 0.351935664055174, 0.056129031890193,
0.090000180000270, 0.000000000000000, 0.090000180000270, -0.000000000000000, 0.000000000000000, 0.090000180000270, 0.000000000000000, 0.000000000000000,
-0.000000000000000, 0.090000180000270, -0.000000000000000, 0.090000180000270, 0.000000000000000, -0.000000000000000, 0.090000180000270, 0.000000000000000,
-0.020000000000000, 0.040000000000000, 0.000000000000000, 0.000000000000000, 0.010000000000000, 0.000000000000000, 0.040000000000000, 0.010000000000000,
0.151935669757191, 0.007741936219615, 0.090000180000270, -0.000000000000000, 0.000000000000000, 0.160967924878730, 0.007741936219615, 0.004516127560770,
-0.104516127560770, 0.351935664055174, 0.000000000000000, 0.090000180000270, 0.040000000000000, 0.007741936219615, 0.351935664055174, 0.056129031890193,
-0.050967744878460, 0.056129031890193, 0.000000000000000, 0.000000000000000, 0.010000000000000, 0.004516127560770, 0.056129031890193, 0.027741936219615;
KeyVector variables {x1, l2, x3};
JointMarginal joint_l2x1x3 = marginals.jointMarginalCovariance(variables);
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(0,0,2,2)), Matrix(joint_l2x1x3(l2,l2)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(2,0,3,2)), Matrix(joint_l2x1x3(x1,l2)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(5,0,3,2)), Matrix(joint_l2x1x3(x3,l2)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(0,2,2,3)), Matrix(joint_l2x1x3(l2,x1)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(2,2,3,3)), Matrix(joint_l2x1x3(x1,x1)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(5,2,3,3)), Matrix(joint_l2x1x3(x3,x1)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(0,5,2,3)), Matrix(joint_l2x1x3(l2,x3)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(2,5,3,3)), Matrix(joint_l2x1x3(x1,x3)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(5,5,3,3)), Matrix(joint_l2x1x3(x3,x3)), 1e-6));
// Check joint marginals for 2 variables (different code path than >2 variable case above)
Matrix expected_l2x1(5,5);
expected_l2x1 <<
0.293871159514111, -0.104516127560770, 0.090000180000270, -0.000000000000000, -0.020000000000000,
-0.104516127560770, 0.391935664055174, 0.000000000000000, 0.090000180000270, 0.040000000000000,
0.090000180000270, 0.000000000000000, 0.090000180000270, -0.000000000000000, 0.000000000000000,
-0.000000000000000, 0.090000180000270, -0.000000000000000, 0.090000180000270, 0.000000000000000,
-0.020000000000000, 0.040000000000000, 0.000000000000000, 0.000000000000000, 0.010000000000000;
variables.resize(2);
variables[0] = l2;
variables[1] = x1;
JointMarginal joint_l2x1 = marginals.jointMarginalCovariance(variables);
EXPECT(assert_equal(Matrix(expected_l2x1.block(0,0,2,2)), Matrix(joint_l2x1(l2,l2)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1.block(2,0,3,2)), Matrix(joint_l2x1(x1,l2)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1.block(0,2,2,3)), Matrix(joint_l2x1(l2,x1)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1.block(2,2,3,3)), Matrix(joint_l2x1(x1,x1)), 1e-6));
// Check joint marginal for 1 variable (different code path than >1 variable cases above)
variables.resize(1);
variables[0] = x1;
JointMarginal joint_x1 = marginals.jointMarginalCovariance(variables);
EXPECT(assert_equal(expectedx1, Matrix(joint_l2x1(x1,x1)), 1e-6));
};
Marginals marginals;
marginals = Marginals(graph, soln, Marginals::CHOLESKY);
testMarginals(marginals);
marginals = Marginals(graph, soln, Marginals::QR); marginals = Marginals(graph, soln, Marginals::QR);
EXPECT(assert_equal(expectedx1, marginals.marginalCovariance(x1), 1e-8)); testMarginals(marginals);
EXPECT(assert_equal(expectedx2, marginals.marginalCovariance(x2), 1e-8)); testJointMarginals(marginals);
EXPECT(assert_equal(expectedx3, marginals.marginalCovariance(x3), 1e-8));
EXPECT(assert_equal(expectedl1, marginals.marginalCovariance(l1), 1e-8));
EXPECT(assert_equal(expectedl2, marginals.marginalCovariance(l2), 1e-8));
// Check joint marginals for 3 variables GaussianFactorGraph gfg = *graph.linearize(soln);
Matrix expected_l2x1x3(8,8); marginals = Marginals(gfg, soln_lin, Marginals::CHOLESKY);
expected_l2x1x3 << testMarginals(marginals);
0.293871159514111, -0.104516127560770, 0.090000180000270, -0.000000000000000, -0.020000000000000, 0.151935669757191, -0.104516127560770, -0.050967744878460, marginals = Marginals(gfg, soln_lin, Marginals::QR);
-0.104516127560770, 0.391935664055174, 0.000000000000000, 0.090000180000270, 0.040000000000000, 0.007741936219615, 0.351935664055174, 0.056129031890193, testMarginals(marginals);
0.090000180000270, 0.000000000000000, 0.090000180000270, -0.000000000000000, 0.000000000000000, 0.090000180000270, 0.000000000000000, 0.000000000000000, testJointMarginals(marginals);
-0.000000000000000, 0.090000180000270, -0.000000000000000, 0.090000180000270, 0.000000000000000, -0.000000000000000, 0.090000180000270, 0.000000000000000,
-0.020000000000000, 0.040000000000000, 0.000000000000000, 0.000000000000000, 0.010000000000000, 0.000000000000000, 0.040000000000000, 0.010000000000000,
0.151935669757191, 0.007741936219615, 0.090000180000270, -0.000000000000000, 0.000000000000000, 0.160967924878730, 0.007741936219615, 0.004516127560770,
-0.104516127560770, 0.351935664055174, 0.000000000000000, 0.090000180000270, 0.040000000000000, 0.007741936219615, 0.351935664055174, 0.056129031890193,
-0.050967744878460, 0.056129031890193, 0.000000000000000, 0.000000000000000, 0.010000000000000, 0.004516127560770, 0.056129031890193, 0.027741936219615;
KeyVector variables {x1, l2, x3};
JointMarginal joint_l2x1x3 = marginals.jointMarginalCovariance(variables);
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(0,0,2,2)), Matrix(joint_l2x1x3(l2,l2)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(2,0,3,2)), Matrix(joint_l2x1x3(x1,l2)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(5,0,3,2)), Matrix(joint_l2x1x3(x3,l2)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(0,2,2,3)), Matrix(joint_l2x1x3(l2,x1)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(2,2,3,3)), Matrix(joint_l2x1x3(x1,x1)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(5,2,3,3)), Matrix(joint_l2x1x3(x3,x1)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(0,5,2,3)), Matrix(joint_l2x1x3(l2,x3)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(2,5,3,3)), Matrix(joint_l2x1x3(x1,x3)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1x3.block(5,5,3,3)), Matrix(joint_l2x1x3(x3,x3)), 1e-6));
// Check joint marginals for 2 variables (different code path than >2 variable case above)
Matrix expected_l2x1(5,5);
expected_l2x1 <<
0.293871159514111, -0.104516127560770, 0.090000180000270, -0.000000000000000, -0.020000000000000,
-0.104516127560770, 0.391935664055174, 0.000000000000000, 0.090000180000270, 0.040000000000000,
0.090000180000270, 0.000000000000000, 0.090000180000270, -0.000000000000000, 0.000000000000000,
-0.000000000000000, 0.090000180000270, -0.000000000000000, 0.090000180000270, 0.000000000000000,
-0.020000000000000, 0.040000000000000, 0.000000000000000, 0.000000000000000, 0.010000000000000;
variables.resize(2);
variables[0] = l2;
variables[1] = x1;
JointMarginal joint_l2x1 = marginals.jointMarginalCovariance(variables);
EXPECT(assert_equal(Matrix(expected_l2x1.block(0,0,2,2)), Matrix(joint_l2x1(l2,l2)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1.block(2,0,3,2)), Matrix(joint_l2x1(x1,l2)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1.block(0,2,2,3)), Matrix(joint_l2x1(l2,x1)), 1e-6));
EXPECT(assert_equal(Matrix(expected_l2x1.block(2,2,3,3)), Matrix(joint_l2x1(x1,x1)), 1e-6));
// Check joint marginal for 1 variable (different code path than >1 variable cases above)
variables.resize(1);
variables[0] = x1;
JointMarginal joint_x1 = marginals.jointMarginalCovariance(variables);
EXPECT(assert_equal(expectedx1, Matrix(joint_l2x1(x1,x1)), 1e-6));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -222,17 +237,26 @@ TEST(Marginals, order) {
vals.at<Pose2>(3).bearing(vals.at<Point2>(101)), vals.at<Pose2>(3).bearing(vals.at<Point2>(101)),
vals.at<Pose2>(3).range(vals.at<Point2>(101)), noiseModel::Unit::Create(2)); vals.at<Pose2>(3).range(vals.at<Point2>(101)), noiseModel::Unit::Create(2));
auto testMarginals = [&] (Marginals marginals, KeySet set) {
KeyVector keys(set.begin(), set.end());
JointMarginal joint = marginals.jointMarginalCovariance(keys);
LONGS_EQUAL(3, (long)joint(0,0).rows());
LONGS_EQUAL(3, (long)joint(1,1).rows());
LONGS_EQUAL(3, (long)joint(2,2).rows());
LONGS_EQUAL(3, (long)joint(3,3).rows());
LONGS_EQUAL(2, (long)joint(100,100).rows());
LONGS_EQUAL(2, (long)joint(101,101).rows());
};
Marginals marginals(fg, vals); Marginals marginals(fg, vals);
KeySet set = fg.keys(); KeySet set = fg.keys();
KeyVector keys(set.begin(), set.end()); testMarginals(marginals, set);
JointMarginal joint = marginals.jointMarginalCovariance(keys);
LONGS_EQUAL(3, (long)joint(0,0).rows()); GaussianFactorGraph gfg = *fg.linearize(vals);
LONGS_EQUAL(3, (long)joint(1,1).rows()); marginals = Marginals(gfg, vals);
LONGS_EQUAL(3, (long)joint(2,2).rows()); set = gfg.keys();
LONGS_EQUAL(3, (long)joint(3,3).rows()); testMarginals(marginals, set);
LONGS_EQUAL(2, (long)joint(100,100).rows());
LONGS_EQUAL(2, (long)joint(101,101).rows());
} }
/* ************************************************************************* */ /* ************************************************************************* */