diff --git a/cpp/testGaussianFactorGraph.cpp b/cpp/testGaussianFactorGraph.cpp index 5a9339e52..79cd9ed07 100644 --- a/cpp/testGaussianFactorGraph.cpp +++ b/cpp/testGaussianFactorGraph.cpp @@ -585,17 +585,45 @@ TEST( GaussianFactorGraph, multiplication ) CHECK(assert_equal(expected,actual)); } +/* ************************************************************************* */ +typedef pair System; + +/** + * gradient of objective function 0.5*|Ax-b|^2 at x = A'*(Ax-b) + */ +Vector gradient(const System& Ab, const Vector& x) { + const Matrix& A = Ab.first; + const Vector& b = Ab.second; + return A ^ (A * x - b); +} + +/** + * Apply operator A + */ +Vector operator*(const System& Ab, const Vector& x) { + const Matrix& A = Ab.first; + return A * x; +} + +/** + * Apply operator A^T + */ +Vector operator^(const System& Ab, const Vector& x) { + const Matrix& A = Ab.first; + return A ^ x; +} + /* ************************************************************************* */ // Method of conjugate gradients (CG) -// "Matrix" class M needs A*v and A^e = trans(A)*v -// "Matrix" class E needs dot(v,v), -v, v+v +// "System" class S needs gradient(S,v), e=S*v, v=S^e // "Vector" class V needs dot(v,v), -v, v+v, s*v -template -V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9) { +// "Vector" class E needs dot(v,v) +template +Vector conjugateGradientDescent(const S& Ab, V x, double threshold = 1e-9) { // Start with g0 = A'*(A*x0-b), d0 = - g0 // i.e., first step is in direction of negative gradient - V g = A ^ (-b + A * x); + V g = gradient(Ab, x); V d = -g; double prev_dotg = dot(g, g); @@ -604,15 +632,15 @@ V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9) for (int k = 1; k <= n; k++) { // calculate optimal step-size - E Ad = A * d; + E Ad = Ab * d; double alpha = -dot(d, g) / dot(Ad, Ad); // do step in new search direction x = x + alpha * d; - if (k==n) break; + if (k == n) break; // update gradient - g = g + alpha * V(A ^ Ad); + g = g + alpha * (Ab ^ Ad); // check for convergence double dotg = dot(g, g); @@ -623,10 +651,17 @@ V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9) prev_dotg = dotg; d = -g + beta * d; } - return x; } +/* ************************************************************************* */ +// Method of conjugate gradients (CG) +Vector conjugateGradientDescent(const Matrix& A, const Vector& b, + const Vector& x, double threshold = 1e-9) { + System Ab = make_pair(A, b); + return conjugateGradientDescent (Ab, x); +} + /* ************************************************************************* */ TEST( GaussianFactorGraph, gradientDescent ) { @@ -656,12 +691,15 @@ TEST( GaussianFactorGraph, gradientDescent ) Vector actualX = conjugateGradientDescent(A,b,x0); Vector expectedX = Vector_(6, -0.1, 0.1, -0.1, -0.1, 0.1, -0.2); CHECK(assert_equal(expectedX,actualX,1e-9)); + + // Do conjugate gradient descent, System version + System Ab = make_pair(A,b); + Vector actualX2 = conjugateGradientDescent(Ab,x0); + CHECK(assert_equal(expectedX,actualX2,1e-9)); } /* ************************************************************************* */ // Tests ported from ConstrainedGaussianFactorGraph -/* ************************************************************************* */ - /* ************************************************************************* */ TEST( GaussianFactorGraph, constrained_simple ) {