Vanilla Conjugate Gradient Descent works

parent 2a2963b7dd
commit 99533f286f
@@ -230,12 +230,44 @@ VectorConfig GaussianFactorGraph::optimalUpdate(const VectorConfig& x,
 /* ************************************************************************* */
 VectorConfig GaussianFactorGraph::gradientDescent(const VectorConfig& x0) const {
   VectorConfig x = x0;
-  int K = 10*x.size();
-  for (int k=0;k<K;k++) {
+  int maxK = 10*x.dim();
+  for (int k=0;k<maxK;k++) {
+    // calculate gradient and check for convergence
     VectorConfig g = gradient(x);
+    double dotg = dot(g,g);
+    if (dotg<1e-9) break;
     x = optimalUpdate(x,g);
   }
   return x;
 }
 
 /* ************************************************************************* */
+// directions are actually negative of those in cgd.lyx
+VectorConfig GaussianFactorGraph::conjugateGradientDescent(
+    const VectorConfig& x0) const {
+
+  // take first step in direction of the gradient
+  VectorConfig d = gradient(x0);
+  VectorConfig x = optimalUpdate(x0,d);
+  double prev_dotg = dot(d,d);
+
+  // loop over remaining (n-1) dimensions
+  int n = d.dim();
+  for (int k=2;k<=n;k++) {
+
+    // calculate gradient and check for convergence
+    VectorConfig gk = gradient(x);
+    double dotg = dot(gk,gk);
+    if (dotg<1e-9) break;
+
+    // calculate new search direction
+    double beta = dotg/prev_dotg;
+    prev_dotg = dotg;
+    d = gk + d * beta;
+
+    // do step in new search direction
+    x = optimalUpdate(x,d);
+  }
+  return x;
+}
+
+/* ************************************************************************* */
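For reference, these are the standard linear conjugate gradient recurrences the new method appears to implement; this summary is my reading of the code, not taken from cgd.lyx. For minimizing $f(x) = \tfrac12 x^\top A x - b^\top x$ with gradient $g_k = A x_k - b$:

$$
\begin{aligned}
d_0 &= -g_0, \\
\alpha_k &= -\frac{d_k^\top g_k}{d_k^\top A d_k}, \qquad x_{k+1} = x_k + \alpha_k d_k, \\
\beta_k &= \frac{g_{k+1}^\top g_{k+1}}{g_k^\top g_k} \;\;\text{(Fletcher--Reeves)}, \qquad d_{k+1} = -g_{k+1} + \beta_k d_k.
\end{aligned}
$$

The member function stores d with the opposite sign (hence the "directions are actually negative" comment) and leaves the exact line search to optimalUpdate, in which case the Fletcher-Reeves ratio dotg/prev_dotg coincides with the exact beta = (d'Ag)/(d'Ad) used in the MATLAB script further down.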
@@ -180,6 +180,13 @@ namespace gtsam {
    * @return solution
    */
   VectorConfig gradientDescent(const VectorConfig& x0) const;
+
+  /**
+   * Find solution using conjugate gradient descent
+   * @param x0: VectorConfig specifying initial estimate
+   * @return solution
+   */
+  VectorConfig conjugateGradientDescent(const VectorConfig& x0) const;
 };
 
 }
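A minimal usage sketch for the two declared solvers, assuming the smallExample.h fixtures (createGaussianFactorGraph, createZeroDelta) that the tests below use; this snippet is illustrative and not part of the commit:

// Both solvers iterate optimalUpdate line searches from an initial config;
// on a linear problem, conjugate gradient needs at most dim(x) of them.
GaussianFactorGraph fg = createGaussianFactorGraph();
VectorConfig zero = createZeroDelta();
VectorConfig xGD = fg.gradientDescent(zero);          // steepest descent
VectorConfig xCG = fg.conjugateGradientDescent(zero); // conjugate gradients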
@@ -19,6 +19,7 @@ using namespace boost::assign;
 #include "Ordering.h"
 #include "smallExample.h"
 #include "GaussianBayesNet.h"
+#include "numericalDerivative.h"
 #include "inference-inl.h" // needed for eliminate and marginals
 
 using namespace gtsam;
@@ -538,6 +539,11 @@ TEST( GaussianFactorGraph, involves )
 }
 
 /* ************************************************************************* */
+double error(const VectorConfig& x) {
+  GaussianFactorGraph fg = createGaussianFactorGraph();
+  return fg.error(x);
+}
+
 TEST( GaussianFactorGraph, gradient )
 {
   GaussianFactorGraph fg = createGaussianFactorGraph();
@@ -547,15 +553,19 @@ TEST( GaussianFactorGraph, gradient )
 
   // 2*f(x) = 100*(x1+c["x1"])^2 + 100*(x2-x1-[0.2;-0.1])^2 + 25*(l1-x1-[0.0;0.2])^2 + 25*(l1-x2-[-0.2;0.3])^2
   // worked out: df/dx1 = 100*[0.1;0.1] + 100*[0.2;-0.1]) + 25*[0.0;0.2] = [10+20;10-10+5] = [30;5]
-  expected.insert("x1",Vector_(2,30.0,5.0));
-  expected.insert("x2",Vector_(2,-25.0, 17.5));
   expected.insert("l1",Vector_(2, 5.0,-12.5));
+  expected.insert("x1",Vector_(2, 30.0, 5.0));
+  expected.insert("x2",Vector_(2,-25.0, 17.5));
 
   // Check the gradient at delta=0
   VectorConfig zero = createZeroDelta();
   VectorConfig actual = fg.gradient(zero);
   CHECK(assert_equal(expected,actual));
 
+  // Check it numerically for good measure
+  Vector numerical_g = numericalGradient<VectorConfig>(error,zero,0.001);
+  CHECK(assert_equal(Vector_(6,5.0,-12.5,30.0,5.0,-25.0,17.5),numerical_g));
+
   // Check the gradient at the solution (should be zero)
   Ordering ord;
   ord += "x2","l1","x1";
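The numericalGradient<VectorConfig> call above cross-checks the analytic gradient with finite differences. A self-contained sketch of that technique on a plain std::vector<double>, independent of GTSAM's templated version (the function name and signature here are illustrative only):

#include <functional>
#include <vector>

// Central differences: df/dx_i ~ (f(x + h*e_i) - f(x - h*e_i)) / (2h).
std::vector<double> numericalGradientSketch(
    const std::function<double(const std::vector<double>&)>& f,
    std::vector<double> x, double h = 1e-3) {
  std::vector<double> g(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double xi = x[i];
    x[i] = xi + h; const double fplus  = f(x);
    x[i] = xi - h; const double fminus = f(x);
    x[i] = xi;  // restore before perturbing the next coordinate
    g[i] = (fplus - fminus) / (2.0 * h);
  }
  return g;
}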
@@ -579,6 +589,10 @@ TEST( GaussianFactorGraph, gradientDescent )
   VectorConfig zero = createZeroDelta();
   VectorConfig actual = fg2.gradientDescent(zero);
   CHECK(assert_equal(expected,actual,1e-2));
+
+  // Do conjugate gradient descent
+  VectorConfig actual2 = fg2.conjugateGradientDescent(zero);
+  CHECK(assert_equal(expected,actual2,1e-2));
 }
 
 /* ************************************************************************* */
@@ -0,0 +1,39 @@
+%-----------------------------------------------------------------------
+% frank01.m: try conjugate gradient on our example graph
+%-----------------------------------------------------------------------
+
+% get matrix form H and z
+fg = createGaussianFactorGraph();
+ord = Ordering;
+ord.push_back('x1');
+ord.push_back('x2');
+ord.push_back('l1');
+
+[H,z] = fg.matrix(ord);
+
+% form system of normal equations
+A=H'*H
+b=H'*z
+
+% k=0
+x = zeros(6,1)
+g = A*x-b
+d = -g
+
+for k=1:5
+  alpha = - (d'*g)/(d'*A*d)
+  x = x + alpha*d
+  g = A*x-b
+  beta = (d'*A*g)/(d'*A*d)
+  d = -g + beta*d
+end
+
+% Do gradient descent
+% fg2 = createGaussianFactorGraph();
+% zero = createZeroDelta();
+% actual = fg2.gradientDescent(zero);
+% CHECK(assert_equal(expected,actual,1e-2));
+
+% Do conjugate gradient descent
+% actual2 = fg2.conjugateGradientDescent(zero);
+% CHECK(assert_equal(expected,actual2,1e-2));
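The same iteration in C++, as a self-contained sketch on a tiny SPD system rather than the example graph's normal equations; it uses the script's exact beta = (d'*A*g)/(d'*A*d) instead of the Fletcher-Reeves ratio in conjugateGradientDescent:

#include <cstdio>

int main() {
  // Arbitrary 2x2 SPD system A*x = b (not the example graph).
  const double A[2][2] = {{4.0, 1.0}, {1.0, 3.0}};
  const double b[2] = {1.0, 2.0};
  double x[2] = {0.0, 0.0};

  auto dot = [](const double* u, const double* v) { return u[0]*v[0] + u[1]*v[1]; };
  auto Amul = [&A](const double* v, double* out) {  // out = A*v
    out[0] = A[0][0]*v[0] + A[0][1]*v[1];
    out[1] = A[1][0]*v[0] + A[1][1]*v[1];
  };

  double g[2], d[2], Ad[2];
  Amul(x, g); g[0] -= b[0]; g[1] -= b[1];  // g = A*x - b
  d[0] = -g[0]; d[1] = -g[1];              // d = -g

  for (int k = 1; k <= 2; ++k) {           // exact CG converges in <= n = 2 steps
    Amul(d, Ad);
    const double alpha = -dot(d, g) / dot(d, Ad);
    x[0] += alpha * d[0]; x[1] += alpha * d[1];
    Amul(x, g); g[0] -= b[0]; g[1] -= b[1];
    const double beta = dot(Ad, g) / dot(d, Ad);  // d'*A*g, using A = A'
    d[0] = -g[0] + beta * d[0]; d[1] = -g[1] + beta * d[1];
  }
  std::printf("x = [%g, %g]\n", x[0], x[1]);  // ~[0.0909, 0.6364], i.e. A*x = b
  return 0;
}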
@@ -6,8 +6,8 @@ CHECK('equals',fg.equals(fg2,1e-9));
 
 %-----------------------------------------------------------------------
 % error
-cfg = createZeroDelta();
-actual = fg.error(cfg);
+zero = createZeroDelta();
+actual = fg.error(zero);
 DOUBLES_EQUAL( 5.625, actual, 1e-9 );
 
 %-----------------------------------------------------------------------
@@ -60,25 +60,9 @@ CHECK('eliminateAll', actual1.equals(expected,1e-5));
 
 fg = createGaussianFactorGraph();
 ord = Ordering;
+ord.push_back('x1');
 ord.push_back('x2');
 ord.push_back('l1');
-ord.push_back('x1');
 
-A = fg.matrix(ord);
+[H,z] = fg.matrix(ord);
 
-%-----------------------------------------------------------------------
-% gradientDescent
-
-% Expected solution
-fg = createGaussianFactorGraph();
-expected = fg.optimize_(ord); % destructive
-
-% Do gradient descent
-% fg2 = createGaussianFactorGraph();
-% zero = createZeroDelta();
-% actual = fg2.gradientDescent(zero);
-% CHECK(assert_equal(expected,actual,1e-2));
-
-% Do conjugate gradient descent
-% actual2 = fg2.conjugateGradientDescent(zero);
-% CHECK(assert_equal(expected,actual2,1e-2));