System version of CG

release/4.3a0
Frank Dellaert 2009-12-26 21:25:45 +00:00
parent 8d3918e7f9
commit 6614434b83
1 changed files with 49 additions and 11 deletions

View File

@ -585,17 +585,45 @@ TEST( GaussianFactorGraph, multiplication )
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected,actual));
} }
/* ************************************************************************* */
typedef pair<Matrix,Vector> System;
/**
* gradient of objective function 0.5*|Ax-b|^2 at x = A'*(Ax-b)
*/
Vector gradient(const System& Ab, const Vector& x) {
const Matrix& A = Ab.first;
const Vector& b = Ab.second;
return A ^ (A * x - b);
}
/**
* Apply operator A
*/
Vector operator*(const System& Ab, const Vector& x) {
const Matrix& A = Ab.first;
return A * x;
}
/**
* Apply operator A^T
*/
Vector operator^(const System& Ab, const Vector& x) {
const Matrix& A = Ab.first;
return A ^ x;
}
/* ************************************************************************* */ /* ************************************************************************* */
// Method of conjugate gradients (CG) // Method of conjugate gradients (CG)
// "Matrix" class M needs A*v and A^e = trans(A)*v // "System" class S needs gradient(S,v), e=S*v, v=S^e
// "Matrix" class E needs dot(v,v), -v, v+v
// "Vector" class V needs dot(v,v), -v, v+v, s*v // "Vector" class V needs dot(v,v), -v, v+v, s*v
template<class M, class E, class V> // "Vector" class E needs dot(v,v)
V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9) { template <class S, class V, class E>
Vector conjugateGradientDescent(const S& Ab, V x, double threshold = 1e-9) {
// Start with g0 = A'*(A*x0-b), d0 = - g0 // Start with g0 = A'*(A*x0-b), d0 = - g0
// i.e., first step is in direction of negative gradient // i.e., first step is in direction of negative gradient
V g = A ^ (-b + A * x); V g = gradient(Ab, x);
V d = -g; V d = -g;
double prev_dotg = dot(g, g); double prev_dotg = dot(g, g);
@ -604,15 +632,15 @@ V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9)
for (int k = 1; k <= n; k++) { for (int k = 1; k <= n; k++) {
// calculate optimal step-size // calculate optimal step-size
E Ad = A * d; E Ad = Ab * d;
double alpha = -dot(d, g) / dot(Ad, Ad); double alpha = -dot(d, g) / dot(Ad, Ad);
// do step in new search direction // do step in new search direction
x = x + alpha * d; x = x + alpha * d;
if (k==n) break; if (k == n) break;
// update gradient // update gradient
g = g + alpha * V(A ^ Ad); g = g + alpha * (Ab ^ Ad);
// check for convergence // check for convergence
double dotg = dot(g, g); double dotg = dot(g, g);
@ -623,10 +651,17 @@ V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9)
prev_dotg = dotg; prev_dotg = dotg;
d = -g + beta * d; d = -g + beta * d;
} }
return x; return x;
} }
/* ************************************************************************* */
// Method of conjugate gradients (CG)
Vector conjugateGradientDescent(const Matrix& A, const Vector& b,
const Vector& x, double threshold = 1e-9) {
System Ab = make_pair(A, b);
return conjugateGradientDescent<System, Vector, Vector> (Ab, x);
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST( GaussianFactorGraph, gradientDescent ) TEST( GaussianFactorGraph, gradientDescent )
{ {
@ -656,12 +691,15 @@ TEST( GaussianFactorGraph, gradientDescent )
Vector actualX = conjugateGradientDescent(A,b,x0); Vector actualX = conjugateGradientDescent(A,b,x0);
Vector expectedX = Vector_(6, -0.1, 0.1, -0.1, -0.1, 0.1, -0.2); Vector expectedX = Vector_(6, -0.1, 0.1, -0.1, -0.1, 0.1, -0.2);
CHECK(assert_equal(expectedX,actualX,1e-9)); CHECK(assert_equal(expectedX,actualX,1e-9));
// Do conjugate gradient descent, System version
System Ab = make_pair(A,b);
Vector actualX2 = conjugateGradientDescent<System,Vector,Vector>(Ab,x0);
CHECK(assert_equal(expectedX,actualX2,1e-9));
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Tests ported from ConstrainedGaussianFactorGraph // Tests ported from ConstrainedGaussianFactorGraph
/* ************************************************************************* */
/* ************************************************************************* */ /* ************************************************************************* */
TEST( GaussianFactorGraph, constrained_simple ) TEST( GaussianFactorGraph, constrained_simple )
{ {