Vanilla Conjugate Gradient Descent works

release/4.3a0
Frank Dellaert 2009-12-12 04:44:34 +00:00
parent 2a2963b7dd
commit 99533f286f
5 changed files with 100 additions and 24 deletions


@@ -230,12 +230,44 @@ VectorConfig GaussianFactorGraph::optimalUpdate(const VectorConfig& x,
/* ************************************************************************* */
VectorConfig GaussianFactorGraph::gradientDescent(const VectorConfig& x0) const {
VectorConfig x = x0;
-int K = 10*x.size();
-for (int k=0;k<K;k++) {
+int maxK = 10*x.dim();
+for (int k=0;k<maxK;k++) {
// calculate gradient and check for convergence
VectorConfig g = gradient(x);
double dotg = dot(g,g);
if (dotg<1e-9) break;
x = optimalUpdate(x,g);
}
return x;
}
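// Note: each iteration above is a steepest-descent step. Assuming
// optimalUpdate performs an exact line search along the supplied
// direction d, the step size for the quadratic 0.5*x'*A*x - b'*x
// with gradient g = A*x - b is alpha = -(d'*g)/(d'*A*d), the same
// alpha used in matlab/frank01.m below.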
/* ************************************************************************* */
// directions are actually the negatives of those in cgd.lyx
VectorConfig GaussianFactorGraph::conjugateGradientDescent(
const VectorConfig& x0) const {
// take first step in direction of the gradient
VectorConfig d = gradient(x0);
VectorConfig x = optimalUpdate(x0,d);
double prev_dotg = dot(d,d);
// loop over remaining (n-1) dimensions
int n = d.dim();
for (int k=2;k<=n;k++) {
// calculate gradient and check for convergence
VectorConfig gk = gradient(x);
double dotg = dot(gk,gk);
if (dotg<1e-9) break;
// calculate new search direction
double beta = dotg/prev_dotg;
prev_dotg = dotg;
d = gk + d * beta;
// do step in new search direction
x = optimalUpdate(x,d);
}
return x;
}
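// Note: with the negated-direction convention above, this is the
// Fletcher-Reeves recurrence
//   beta_k = (g_k'*g_k)/(g_{k-1}'*g_{k-1}),   d_k = g_k + beta_k*d_{k-1},
// followed by an exact line search along d_k; in exact arithmetic it
// terminates in at most n steps for an n-dimensional quadratic, hence
// the loop bound above.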
/* ************************************************************************* */


@@ -180,6 +180,13 @@ namespace gtsam {
* @return solution
*/
VectorConfig gradientDescent(const VectorConfig& x0) const;
/**
* Find solution using conjugate gradient descent
* @param x0: VectorConfig specifying initial estimate
* @return solution
*/
VectorConfig conjugateGradientDescent(const VectorConfig& x0) const;
};
}


@@ -19,6 +19,7 @@ using namespace boost::assign;
#include "Ordering.h"
#include "smallExample.h"
#include "GaussianBayesNet.h"
#include "numericalDerivative.h"
#include "inference-inl.h" // needed for eliminate and marginals
using namespace gtsam;
@@ -538,6 +539,11 @@ TEST( GaussianFactorGraph, involves )
}
/* ************************************************************************* */
double error(const VectorConfig& x) {
GaussianFactorGraph fg = createGaussianFactorGraph();
return fg.error(x);
}
TEST( GaussianFactorGraph, gradient )
{
GaussianFactorGraph fg = createGaussianFactorGraph();
@@ -547,15 +553,19 @@ TEST( GaussianFactorGraph, gradient )
// 2*f(x) = 100*(x1+c["x1"])^2 + 100*(x2-x1-[0.2;-0.1])^2 + 25*(l1-x1-[0.0;0.2])^2 + 25*(l1-x2-[-0.2;0.3])^2
// worked out: df/dx1 = 100*[0.1;0.1] + 100*[0.2;-0.1] + 25*[0.0;0.2] = [10+20;10-10+5] = [30;5]
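// similarly: df/dl1 = -25*[0.0;0.2] - 25*[-0.2;0.3]   = [0+5;-5-7.5]   = [5;-12.5]
//            df/dx2 = -100*[0.2;-0.1] + 25*[-0.2;0.3] = [-20-5;10+7.5] = [-25;17.5]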
+expected.insert("l1",Vector_(2, 5.0,-12.5));
expected.insert("x1",Vector_(2, 30.0, 5.0));
expected.insert("x2",Vector_(2,-25.0, 17.5));
-expected.insert("l1",Vector_(2, 5.0,-12.5));
// Check the gradient at delta=0
VectorConfig zero = createZeroDelta();
VectorConfig actual = fg.gradient(zero);
CHECK(assert_equal(expected,actual));
// Check it numerically for good measure
Vector numerical_g = numericalGradient<VectorConfig>(error,zero,0.001);
CHECK(assert_equal(Vector_(6,5.0,-12.5,30.0,5.0,-25.0,17.5),numerical_g));
// Check the gradient at the solution (should be zero)
Ordering ord;
ord += "x2","l1","x1";
@@ -579,6 +589,10 @@ TEST( GaussianFactorGraph, gradientDescent )
VectorConfig zero = createZeroDelta();
VectorConfig actual = fg2.gradientDescent(zero);
CHECK(assert_equal(expected,actual,1e-2));
// Do conjugate gradient descent
VectorConfig actual2 = fg2.conjugateGradientDescent(zero);
CHECK(assert_equal(expected,actual2,1e-2));
}
/* ************************************************************************* */

matlab/frank01.m Normal file

@@ -0,0 +1,39 @@
%-----------------------------------------------------------------------
% frank01.m: try conjugate gradient on our example graph
%-----------------------------------------------------------------------
% get matrix form H and z
fg = createGaussianFactorGraph();
ord = Ordering;
ord.push_back('x1');
ord.push_back('x2');
ord.push_back('l1');
[H,z] = fg.matrix(ord);
% form system of normal equations
A=H'*H
b=H'*z
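% (normal equations: minimizing |H*x - z|^2 gives H'*H*x = H'*z, i.e. A*x = b)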
% k=0
x = zeros(6,1)
g = A*x-b
d = -g
for k=1:5
alpha = - (d'*g)/(d'*A*d)
x = x + alpha*d
g = A*x-b
beta = (d'*A*g)/(d'*A*d)
d = -g + beta*d
end
% Do gradient descent
% fg2 = createGaussianFactorGraph();
% zero = createZeroDelta();
% actual = fg2.gradientDescent(zero);
% CHECK(assert_equal(expected,actual,1e-2));
% Do conjugate gradient descent
% actual2 = fg2.conjugateGradientDescent(zero);
% CHECK(assert_equal(expected,actual2,1e-2));
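For reference, the loop above is the textbook conjugate gradient iteration for the quadratic 1/2 x'Ax - b'x, written with the same symbols as the script; this beta is exactly the choice that makes successive directions A-conjugate (a sketch of the derivation, not part of the commit):

\alpha_k = -\frac{d_k^\top g_k}{d_k^\top A d_k}, \qquad
\beta_k = \frac{d_k^\top A\, g_{k+1}}{d_k^\top A d_k}, \qquad
d_{k+1} = -g_{k+1} + \beta_k d_k,

so that d_{k+1}^\top A d_k = -g_{k+1}^\top A d_k + \beta_k\, d_k^\top A d_k = 0.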


@@ -6,8 +6,8 @@ CHECK('equals',fg.equals(fg2,1e-9));
%-----------------------------------------------------------------------
% error
-cfg = createZeroDelta();
-actual = fg.error(cfg);
+zero = createZeroDelta();
+actual = fg.error(zero);
DOUBLES_EQUAL( 5.625, actual, 1e-9 );
%-----------------------------------------------------------------------
@@ -60,25 +60,9 @@ CHECK('eliminateAll', actual1.equals(expected,1e-5));
fg = createGaussianFactorGraph();
ord = Ordering;
+ord.push_back('x1');
ord.push_back('x2');
ord.push_back('l1');
-ord.push_back('x1');
-A = fg.matrix(ord);
+[H,z] = fg.matrix(ord);
-%-----------------------------------------------------------------------
-% gradientDescent
-% Expected solution
-fg = createGaussianFactorGraph();
-expected = fg.optimize_(ord); % destructive
-% Do gradient descent
-% fg2 = createGaussianFactorGraph();
-% zero = createZeroDelta();
-% actual = fg2.gradientDescent(zero);
-% CHECK(assert_equal(expected,actual,1e-2));
-% Do conjugate gradient descent
-% actual2 = fg2.conjugateGradientDescent(zero);
-% CHECK(assert_equal(expected,actual2,1e-2));