System version of CG
parent
8d3918e7f9
commit
6614434b83
|
@ -585,17 +585,45 @@ TEST( GaussianFactorGraph, multiplication )
|
||||||
CHECK(assert_equal(expected,actual));
|
CHECK(assert_equal(expected,actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
typedef pair<Matrix,Vector> System;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* gradient of objective function 0.5*|Ax-b|^2 at x = A'*(Ax-b)
|
||||||
|
*/
|
||||||
|
Vector gradient(const System& Ab, const Vector& x) {
|
||||||
|
const Matrix& A = Ab.first;
|
||||||
|
const Vector& b = Ab.second;
|
||||||
|
return A ^ (A * x - b);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply operator A
|
||||||
|
*/
|
||||||
|
Vector operator*(const System& Ab, const Vector& x) {
|
||||||
|
const Matrix& A = Ab.first;
|
||||||
|
return A * x;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply operator A^T
|
||||||
|
*/
|
||||||
|
Vector operator^(const System& Ab, const Vector& x) {
|
||||||
|
const Matrix& A = Ab.first;
|
||||||
|
return A ^ x;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Method of conjugate gradients (CG)
|
// Method of conjugate gradients (CG)
|
||||||
// "Matrix" class M needs A*v and A^e = trans(A)*v
|
// "System" class S needs gradient(S,v), e=S*v, v=S^e
|
||||||
// "Matrix" class E needs dot(v,v), -v, v+v
|
|
||||||
// "Vector" class V needs dot(v,v), -v, v+v, s*v
|
// "Vector" class V needs dot(v,v), -v, v+v, s*v
|
||||||
template<class M, class E, class V>
|
// "Vector" class E needs dot(v,v)
|
||||||
V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9) {
|
template <class S, class V, class E>
|
||||||
|
Vector conjugateGradientDescent(const S& Ab, V x, double threshold = 1e-9) {
|
||||||
|
|
||||||
// Start with g0 = A'*(A*x0-b), d0 = - g0
|
// Start with g0 = A'*(A*x0-b), d0 = - g0
|
||||||
// i.e., first step is in direction of negative gradient
|
// i.e., first step is in direction of negative gradient
|
||||||
V g = A ^ (-b + A * x);
|
V g = gradient(Ab, x);
|
||||||
V d = -g;
|
V d = -g;
|
||||||
double prev_dotg = dot(g, g);
|
double prev_dotg = dot(g, g);
|
||||||
|
|
||||||
|
@ -604,15 +632,15 @@ V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9)
|
||||||
for (int k = 1; k <= n; k++) {
|
for (int k = 1; k <= n; k++) {
|
||||||
|
|
||||||
// calculate optimal step-size
|
// calculate optimal step-size
|
||||||
E Ad = A * d;
|
E Ad = Ab * d;
|
||||||
double alpha = -dot(d, g) / dot(Ad, Ad);
|
double alpha = -dot(d, g) / dot(Ad, Ad);
|
||||||
|
|
||||||
// do step in new search direction
|
// do step in new search direction
|
||||||
x = x + alpha * d;
|
x = x + alpha * d;
|
||||||
if (k==n) break;
|
if (k == n) break;
|
||||||
|
|
||||||
// update gradient
|
// update gradient
|
||||||
g = g + alpha * V(A ^ Ad);
|
g = g + alpha * (Ab ^ Ad);
|
||||||
|
|
||||||
// check for convergence
|
// check for convergence
|
||||||
double dotg = dot(g, g);
|
double dotg = dot(g, g);
|
||||||
|
@ -623,10 +651,17 @@ V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9)
|
||||||
prev_dotg = dotg;
|
prev_dotg = dotg;
|
||||||
d = -g + beta * d;
|
d = -g + beta * d;
|
||||||
}
|
}
|
||||||
|
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Method of conjugate gradients (CG)
|
||||||
|
Vector conjugateGradientDescent(const Matrix& A, const Vector& b,
|
||||||
|
const Vector& x, double threshold = 1e-9) {
|
||||||
|
System Ab = make_pair(A, b);
|
||||||
|
return conjugateGradientDescent<System, Vector, Vector> (Ab, x);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( GaussianFactorGraph, gradientDescent )
|
TEST( GaussianFactorGraph, gradientDescent )
|
||||||
{
|
{
|
||||||
|
@ -656,12 +691,15 @@ TEST( GaussianFactorGraph, gradientDescent )
|
||||||
Vector actualX = conjugateGradientDescent(A,b,x0);
|
Vector actualX = conjugateGradientDescent(A,b,x0);
|
||||||
Vector expectedX = Vector_(6, -0.1, 0.1, -0.1, -0.1, 0.1, -0.2);
|
Vector expectedX = Vector_(6, -0.1, 0.1, -0.1, -0.1, 0.1, -0.2);
|
||||||
CHECK(assert_equal(expectedX,actualX,1e-9));
|
CHECK(assert_equal(expectedX,actualX,1e-9));
|
||||||
|
|
||||||
|
// Do conjugate gradient descent, System version
|
||||||
|
System Ab = make_pair(A,b);
|
||||||
|
Vector actualX2 = conjugateGradientDescent<System,Vector,Vector>(Ab,x0);
|
||||||
|
CHECK(assert_equal(expectedX,actualX2,1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Tests ported from ConstrainedGaussianFactorGraph
|
// Tests ported from ConstrainedGaussianFactorGraph
|
||||||
/* ************************************************************************* */
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( GaussianFactorGraph, constrained_simple )
|
TEST( GaussianFactorGraph, constrained_simple )
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue