Vanilla Conjugate Gradient Descent works

release/4.3a0
Frank Dellaert 2009-12-12 04:44:34 +00:00
parent 2a2963b7dd
commit 99533f286f
5 changed files with 100 additions and 24 deletions

@@ -230,12 +230,44 @@ VectorConfig GaussianFactorGraph::optimalUpdate(const VectorConfig& x,
 /* ************************************************************************* */
 VectorConfig GaussianFactorGraph::gradientDescent(const VectorConfig& x0) const {
   VectorConfig x = x0;
-  int K = 10*x.size();
-  for (int k=0;k<K;k++) {
+  int maxK = 10*x.dim();
+  for (int k=0;k<maxK;k++) {
+    // calculate gradient and check for convergence
     VectorConfig g = gradient(x);
+    double dotg = dot(g,g);
+    if (dotg<1e-9) break;
     x = optimalUpdate(x,g);
   }
   return x;
 }
 /* ************************************************************************* */
+// directions are actually negative of those in cgd.lyx
+VectorConfig GaussianFactorGraph::conjugateGradientDescent(
+    const VectorConfig& x0) const {
+  // take first step in direction of the gradient
+  VectorConfig d = gradient(x0);
+  VectorConfig x = optimalUpdate(x0,d);
+  double prev_dotg = dot(d,d);
+  // loop over remaining (n-1) dimensions
+  int n = d.dim();
+  for (int k=2;k<=n;k++) {
+    // calculate gradient and check for convergence
+    VectorConfig gk = gradient(x);
+    double dotg = dot(gk,gk);
+    if (dotg<1e-9) break;
+    // calculate new search direction
+    double beta = dotg/prev_dotg;
+    prev_dotg = dotg;
+    d = gk + d * beta;
+    // do step in new search direction
+    x = optimalUpdate(x,d);
+  }
+  return x;
+}
+/* ************************************************************************* */
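optimalUpdate (whose signature appears in the hunk header above) supplies the exact line search both methods rely on. On the quadratic error 0.5*x'Ax - b'x the minimizer along x + alpha*d has the closed form alpha = -(d'g)/(d'Ad) with gradient g = Ax - b, the same formula frank01.m uses below. A minimal standalone sketch of that step, with plain std::vector<double> and a dense matrix standing in for VectorConfig and the factor graph (Vec, Mat, matVec, and grad are made-up helpers, not GTSAM API):

#include <vector>

typedef std::vector<double> Vec;
typedef std::vector<Vec> Mat;

double dot(const Vec& a, const Vec& b) {
  double s = 0.0;
  for (size_t i = 0; i < a.size(); i++) s += a[i]*b[i];
  return s;
}

Vec matVec(const Mat& A, const Vec& x) {
  Vec y(A.size());
  for (size_t i = 0; i < A.size(); i++) y[i] = dot(A[i], x);
  return y;
}

// gradient of 0.5*x'Ax - b'x at x
Vec grad(const Mat& A, const Vec& b, const Vec& x) {
  Vec g = matVec(A, x);
  for (size_t i = 0; i < g.size(); i++) g[i] -= b[i];
  return g;
}

// exact line search: minimize over alpha along x + alpha*d,
// which gives alpha = -(d'g)/(d'Ad)
Vec optimalUpdate(const Mat& A, const Vec& b, const Vec& x, const Vec& d) {
  Vec g = grad(A, b, x);
  double alpha = -dot(d, g) / dot(d, matVec(A, d));
  Vec xnew(x);
  for (size_t i = 0; i < x.size(); i++) xnew[i] += alpha * d[i];
  return xnew;
}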

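Building on those helpers, the conjugateGradientDescent loop above translates to the following Fletcher-Reeves iteration (beta = g_k'g_k / g_{k-1}'g_{k-1}, the dotg/prev_dotg quotient in the hunk). Again a sketch, not the GTSAM implementation; like the code above it lets d carry the sign of the gradient, so optimalUpdate's alpha simply comes out negative:

// Fletcher-Reeves CG on the same quadratic, reusing the helpers above
Vec conjugateGradientDescent(const Mat& A, const Vec& b, const Vec& x0) {
  // take first step in direction of the gradient
  Vec d = grad(A, b, x0);
  Vec x = optimalUpdate(A, b, x0, d);
  double prev_dotg = dot(d, d);
  // an n-dimensional quadratic needs at most n-1 further steps
  for (size_t k = 2; k <= x.size(); k++) {
    Vec gk = grad(A, b, x);
    double dotg = dot(gk, gk);
    if (dotg < 1e-9) break; // converged
    // new search direction: d_k = g_k + beta * d_{k-1}
    double beta = dotg / prev_dotg;
    prev_dotg = dotg;
    for (size_t i = 0; i < d.size(); i++) d[i] = gk[i] + beta * d[i];
    x = optimalUpdate(A, b, x, d);
  }
  return x;
}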
@@ -180,6 +180,13 @@ namespace gtsam {
    * @return solution
    */
   VectorConfig gradientDescent(const VectorConfig& x0) const;
+
+  /**
+   * Find solution using conjugate gradient descent
+   * @param x0: VectorConfig specifying initial estimate
+   * @return solution
+   */
+  VectorConfig conjugateGradientDescent(const VectorConfig& x0) const;
 };
 }

@@ -19,6 +19,7 @@ using namespace boost::assign;
 #include "Ordering.h"
 #include "smallExample.h"
 #include "GaussianBayesNet.h"
+#include "numericalDerivative.h"
 #include "inference-inl.h" // needed for eliminate and marginals

 using namespace gtsam;
@@ -538,6 +539,11 @@ TEST( GaussianFactorGraph, involves )
 }
 /* ************************************************************************* */
+double error(const VectorConfig& x) {
+  GaussianFactorGraph fg = createGaussianFactorGraph();
+  return fg.error(x);
+}
+
 TEST( GaussianFactorGraph, gradient )
 {
   GaussianFactorGraph fg = createGaussianFactorGraph();
@@ -547,15 +553,19 @@ TEST( GaussianFactorGraph, gradient )
   // 2*f(x) = 100*(x1+c["x1"])^2 + 100*(x2-x1-[0.2;-0.1])^2 + 25*(l1-x1-[0.0;0.2])^2 + 25*(l1-x2-[-0.2;0.3])^2
   // worked out: df/dx1 = 100*[0.1;0.1] + 100*[0.2;-0.1] + 25*[0.0;0.2] = [10+20;10-10+5] = [30;5]
-  expected.insert("x1",Vector_(2,30.0,5.0));
-  expected.insert("x2",Vector_(2,-25.0, 17.5));
   expected.insert("l1",Vector_(2, 5.0,-12.5));
+  expected.insert("x1",Vector_(2, 30.0, 5.0));
+  expected.insert("x2",Vector_(2,-25.0, 17.5));

   // Check the gradient at delta=0
   VectorConfig zero = createZeroDelta();
   VectorConfig actual = fg.gradient(zero);
   CHECK(assert_equal(expected,actual));
+
+  // Check it numerically for good measure
+  Vector numerical_g = numericalGradient<VectorConfig>(error,zero,0.001);
+  CHECK(assert_equal(Vector_(6,5.0,-12.5,30.0,5.0,-25.0,17.5),numerical_g));

   // Check the gradient at the solution (should be zero)
   Ordering ord;
   ord += "x2","l1","x1";
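numericalGradient's implementation is not part of this diff; a central-difference sketch of the kind of check it performs, assuming a flat std::vector<double> argument in place of VectorConfig (numericalGradientSketch is a hypothetical name):

#include <vector>

// central differences: g_i approx (f(x + h*e_i) - f(x - h*e_i)) / (2h)
std::vector<double> numericalGradientSketch(
    double (*f)(const std::vector<double>&),
    std::vector<double> x, double h) {
  std::vector<double> g(x.size());
  for (size_t i = 0; i < x.size(); i++) {
    double xi = x[i];
    x[i] = xi + h; double fplus = f(x);
    x[i] = xi - h; double fminus = f(x);
    x[i] = xi; // restore the perturbed coordinate
    g[i] = (fplus - fminus) / (2.0 * h);
  }
  return g;
}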
@@ -579,6 +589,10 @@ TEST( GaussianFactorGraph, gradientDescent )
   VectorConfig zero = createZeroDelta();
   VectorConfig actual = fg2.gradientDescent(zero);
   CHECK(assert_equal(expected,actual,1e-2));
+
+  // Do conjugate gradient descent
+  VectorConfig actual2 = fg2.conjugateGradientDescent(zero);
+  CHECK(assert_equal(expected,actual2,1e-2));
 }
 /* ************************************************************************* */

matlab/frank01.m (new file)
@@ -0,0 +1,39 @@
+%-----------------------------------------------------------------------
+% frank01.m: try conjugate gradient on our example graph
+%-----------------------------------------------------------------------
+
+% get matrix form H and z
+fg = createGaussianFactorGraph();
+ord = Ordering;
+ord.push_back('x1');
+ord.push_back('x2');
+ord.push_back('l1');
+[H,z] = fg.matrix(ord);
+
+% form system of normal equations
+A = H'*H
+b = H'*z
+
+% k=0
+x = zeros(6,1)
+g = A*x-b
+d = -g
+
+for k=1:5
+  alpha = -(d'*g)/(d'*A*d)
+  x = x + alpha*d
+  g = A*x-b
+  beta = (d'*A*g)/(d'*A*d)
+  d = -g + beta*d
+end
+
+% Do gradient descent
+% fg2 = createGaussianFactorGraph();
+% zero = createZeroDelta();
+% actual = fg2.gradientDescent(zero);
+% CHECK(assert_equal(expected,actual,1e-2));
+
+% Do conjugate gradient descent
+% actual2 = fg2.conjugateGradientDescent(zero);
+% CHECK(assert_equal(expected,actual2,1e-2));
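The loop above is the textbook CG recursion for the normal equations Ax = b with A = H'H and b = H'z; written out:

\alpha_k = -\frac{d_k^\top g_k}{d_k^\top A d_k}, \qquad
x_{k+1} = x_k + \alpha_k d_k, \qquad
g_{k+1} = A x_{k+1} - b,

\beta_k = \frac{d_k^\top A\, g_{k+1}}{d_k^\top A d_k}, \qquad
d_{k+1} = -g_{k+1} + \beta_k d_k.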

@@ -6,8 +6,8 @@ CHECK('equals',fg.equals(fg2,1e-9));
 %-----------------------------------------------------------------------
 % error
-cfg = createZeroDelta();
-actual = fg.error(cfg);
+zero = createZeroDelta();
+actual = fg.error(zero);
 DOUBLES_EQUAL( 5.625, actual, 1e-9 );

 %-----------------------------------------------------------------------
@@ -60,25 +60,9 @@ CHECK('eliminateAll', actual1.equals(expected,1e-5));
 fg = createGaussianFactorGraph();
 ord = Ordering;
-ord.push_back('x1');
 ord.push_back('x2');
 ord.push_back('l1');
-A = fg.matrix(ord);
+ord.push_back('x1');
+[H,z] = fg.matrix(ord);
-
-%-----------------------------------------------------------------------
-% gradientDescent
-
-% Expected solution
-fg = createGaussianFactorGraph();
-expected = fg.optimize_(ord); % destructive
-
-% Do gradient descent
-% fg2 = createGaussianFactorGraph();
-% zero = createZeroDelta();
-% actual = fg2.gradientDescent(zero);
-% CHECK(assert_equal(expected,actual,1e-2));
-
-% Do conjugate gradient descent
-% actual2 = fg2.conjugateGradientDescent(zero);
-% CHECK(assert_equal(expected,actual2,1e-2));