Conjugate Gradient Descent template (in progress)
parent
886c7dcdcc
commit
f3965b07ca
|
@ -544,6 +544,7 @@ double error(const VectorConfig& x) {
|
||||||
return fg.error(x);
|
return fg.error(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
TEST( GaussianFactorGraph, gradient )
|
TEST( GaussianFactorGraph, gradient )
|
||||||
{
|
{
|
||||||
GaussianFactorGraph fg = createGaussianFactorGraph();
|
GaussianFactorGraph fg = createGaussianFactorGraph();
|
||||||
|
@ -575,12 +576,63 @@ TEST( GaussianFactorGraph, gradient )
|
||||||
CHECK(assert_equal(zero,actual2));
|
CHECK(assert_equal(zero,actual2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* *
|
||||||
|
TEST( GaussianFactorGraph, multiplication )
|
||||||
|
{
|
||||||
|
GaussianFactorGraph A = createGaussianFactorGraph();
|
||||||
|
VectorConfig x = createConfig();
|
||||||
|
ErrorConfig actual = A * x;
|
||||||
|
CHECK(assert_equal(expected,actual));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Method of conjugate gradients (CG)
|
||||||
|
// "Matrix" class M needs A*v and A^e = trans(A)*v
|
||||||
|
// "Matrix" class E needs dot(v,v), -v, v+v
|
||||||
|
// "Vector" class V needs dot(v,v), -v, v+v, s*v
|
||||||
|
template<class M, class E, class V>
|
||||||
|
V conjugateGradientDescent(const M& A, const E& b, V x, double threshold = 1e-9) {
|
||||||
|
|
||||||
|
// Start with g0 = A'*(A*x0-b), d0 = - g0
|
||||||
|
// i.e., first step is in direction of negative gradient
|
||||||
|
V g = A ^ (-b + A * x);
|
||||||
|
V d = -g;
|
||||||
|
double prev_dotg = dot(g, g);
|
||||||
|
|
||||||
|
// loop max n times
|
||||||
|
size_t n = x.size();
|
||||||
|
for (int k = 1; k <= n; k++) {
|
||||||
|
|
||||||
|
// calculate optimal step-size
|
||||||
|
E Ad = A * d;
|
||||||
|
double alpha = -dot(d, g) / dot(Ad, Ad);
|
||||||
|
|
||||||
|
// do step in new search direction
|
||||||
|
x = x + alpha * d;
|
||||||
|
if (k==n) break;
|
||||||
|
|
||||||
|
// update gradient
|
||||||
|
g = g + alpha * V(A ^ Ad);
|
||||||
|
|
||||||
|
// check for convergence
|
||||||
|
double dotg = dot(g, g);
|
||||||
|
if (dotg < threshold) break;
|
||||||
|
|
||||||
|
// calculate new search direction
|
||||||
|
double beta = dotg / prev_dotg;
|
||||||
|
prev_dotg = dotg;
|
||||||
|
d = -g + beta * d;
|
||||||
|
}
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( GaussianFactorGraph, gradientDescent )
|
TEST( GaussianFactorGraph, gradientDescent )
|
||||||
{
|
{
|
||||||
// Expected solution
|
// Expected solution
|
||||||
Ordering ord;
|
Ordering ord;
|
||||||
ord += "x2","l1","x1";
|
ord += "l1","x1","x2";
|
||||||
GaussianFactorGraph fg = createGaussianFactorGraph();
|
GaussianFactorGraph fg = createGaussianFactorGraph();
|
||||||
VectorConfig expected = fg.optimize(ord); // destructive
|
VectorConfig expected = fg.optimize(ord); // destructive
|
||||||
|
|
||||||
|
@ -592,7 +644,18 @@ TEST( GaussianFactorGraph, gradientDescent )
|
||||||
|
|
||||||
// Do conjugate gradient descent
|
// Do conjugate gradient descent
|
||||||
VectorConfig actual2 = fg2.conjugateGradientDescent(zero);
|
VectorConfig actual2 = fg2.conjugateGradientDescent(zero);
|
||||||
|
//VectorConfig actual2 = conjugateGradientDescent(fg2,zero,zero);
|
||||||
CHECK(assert_equal(expected,actual2,1e-2));
|
CHECK(assert_equal(expected,actual2,1e-2));
|
||||||
|
|
||||||
|
// Do conjugate gradient descent, Matrix version
|
||||||
|
Matrix A;Vector b;
|
||||||
|
boost::tie(A,b) = fg2.matrix(ord);
|
||||||
|
// print(A,"A");
|
||||||
|
// print(b,"b");
|
||||||
|
Vector x0 = gtsam::zero(6);
|
||||||
|
Vector actualX = conjugateGradientDescent(A,b,x0);
|
||||||
|
Vector expectedX = Vector_(6, -0.1, 0.1, -0.1, -0.1, 0.1, -0.2);
|
||||||
|
CHECK(assert_equal(expectedX,actualX,1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue