Conjugate Gradient Descent template (in progress)

release/4.3a0
Frank Dellaert 2009-12-26 15:06:54 +00:00
parent 886c7dcdcc
commit f3965b07ca
1 changed file with 64 additions and 1 deletion


@@ -544,6 +544,7 @@ double error(const VectorConfig& x) {
return fg.error(x);
}
/* ************************************************************************* */
TEST( GaussianFactorGraph, gradient )
{
GaussianFactorGraph fg = createGaussianFactorGraph();
@@ -575,12 +576,63 @@ TEST( GaussianFactorGraph, gradient )
CHECK(assert_equal(zero,actual2));
}
/* ************************************************************************* *
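// NOTE: the divider above is missing its closing '/', so this test stays
// commented out until the next full divider; 'expected' below is not yet defined.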
TEST( GaussianFactorGraph, multiplication )
{
GaussianFactorGraph A = createGaussianFactorGraph();
VectorConfig x = createConfig();
ErrorConfig actual = A * x;
CHECK(assert_equal(expected,actual));
}
/* ************************************************************************* */
// Method of conjugate gradients (CG)
// "Matrix" class M needs A*v and A^e = trans(A)*v
// "Matrix" class E needs dot(v,v), -v, v+v
// "Vector" class V needs dot(v,v), -v, v+v, s*v
template<class M, class E, class V>
V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9) {
// Start with g0 = A'*(A*x0-b), d0 = - g0
// i.e., first step is in direction of negative gradient
V g = A ^ (-b + A * x);
V d = -g;
double prev_dotg = dot(g, g);
// loop max n times
size_t n = x.size();
for (size_t k = 1; k <= n; k++) {
// calculate optimal step-size
E Ad = A * d;
double alpha = -dot(d, g) / dot(Ad, Ad);
// do step in new search direction
x = x + alpha * d;
if (k==n) break;
// update gradient
g = g + alpha * V(A ^ Ad);
// check for convergence (threshold applies to the squared gradient norm)
double dotg = dot(g, g);
if (dotg < threshold) break;
// calculate new search direction
double beta = dotg / prev_dotg;
prev_dotg = dotg;
d = -g + beta * d;
}
return x;
}
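// Usage sketch (hypothetical values): with gtsam's dense types the template
// instantiates as M=Matrix, E=Vector, V=Vector, since Matrix provides
// operator* and operator^ (A ^ v == trans(A) * v) and Vector provides
// dot, unary -, +, and scalar *. For example, with A = diag(2,3):
//   Matrix A = Matrix_(2,2, 2.0,0.0, 0.0,3.0);
//   Vector b = Vector_(2, 4.0, 9.0);
//   Vector x = conjugateGradientDescent(A, b, gtsam::zero(2)); // -> (2.0, 3.0)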
/* ************************************************************************* */
TEST( GaussianFactorGraph, gradientDescent )
{
// Expected solution
Ordering ord;
ord += "x2","l1","x1";
ord += "l1","x1","x2";
GaussianFactorGraph fg = createGaussianFactorGraph();
VectorConfig expected = fg.optimize(ord); // destructive
@@ -592,7 +644,18 @@ TEST( GaussianFactorGraph, gradientDescent )
// Do conjugate gradient descent
VectorConfig actual2 = fg2.conjugateGradientDescent(zero);
//VectorConfig actual2 = conjugateGradientDescent(fg2,zero,zero);
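// (the line above, once enabled, would instantiate the template on the graph
//  itself, presumably with M=GaussianFactorGraph, E=ErrorConfig, V=VectorConfig)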
CHECK(assert_equal(expected,actual2,1e-2));
// Do conjugate gradient descent, Matrix version
Matrix A; Vector b;
boost::tie(A,b) = fg2.matrix(ord);
// print(A,"A");
// print(b,"b");
Vector x0 = gtsam::zero(6);
Vector actualX = conjugateGradientDescent(A,b,x0);
Vector expectedX = Vector_(6, -0.1, 0.1, -0.1, -0.1, 0.1, -0.2);
CHECK(assert_equal(expectedX,actualX,1e-9));
}
/* ************************************************************************* */