diff --git a/cpp/GaussianFactorGraph.cpp b/cpp/GaussianFactorGraph.cpp index fb8053c2b..341f6bc9d 100644 --- a/cpp/GaussianFactorGraph.cpp +++ b/cpp/GaussianFactorGraph.cpp @@ -230,12 +230,44 @@ VectorConfig GaussianFactorGraph::optimalUpdate(const VectorConfig& x, /* ************************************************************************* */ VectorConfig GaussianFactorGraph::gradientDescent(const VectorConfig& x0) const { VectorConfig x = x0; - int K = 10*x.size(); - for (int k=0;k(error,zero,0.001); + CHECK(assert_equal(Vector_(6,5.0,-12.5,30.0,5.0,-25.0,17.5),numerical_g)); + // Check the gradient at the solution (should be zero) Ordering ord; ord += "x2","l1","x1"; @@ -579,6 +589,10 @@ TEST( GaussianFactorGraph, gradientDescent ) VectorConfig zero = createZeroDelta(); VectorConfig actual = fg2.gradientDescent(zero); CHECK(assert_equal(expected,actual,1e-2)); + + // Do conjugate gradient descent + VectorConfig actual2 = fg2.conjugateGradientDescent(zero); + CHECK(assert_equal(expected,actual2,1e-2)); } /* ************************************************************************* */ diff --git a/matlab/frank01.m b/matlab/frank01.m new file mode 100644 index 000000000..0bf9ab5fe --- /dev/null +++ b/matlab/frank01.m @@ -0,0 +1,39 @@ +%----------------------------------------------------------------------- +% frank01.m: try conjugate gradient on our example graph +%----------------------------------------------------------------------- + +% get matrix form H and z +fg = createGaussianFactorGraph(); +ord = Ordering; +ord.push_back('x1'); +ord.push_back('x2'); +ord.push_back('l1'); + +[H,z] = fg.matrix(ord); + +% form system of normal equations +A=H'*H +b=H'*z + +% k=0 +x = zeros(6,1) +g = A*x-b +d = -g + +for k=1:5 + alpha = - (d'*g)/(d'*A*d) + x = x + alpha*d + g = A*x-b + beta = (d'*A*g)/(d'*A*d) + d = -g + beta*d +end + +% Do gradient descent +% fg2 = createGaussianFactorGraph(); +% zero = createZeroDelta(); +% actual = fg2.gradientDescent(zero); +% CHECK(assert_equal(expected,actual,1e-2)); + +% Do conjugate gradient descent +% actual2 = fg2.conjugateGradientDescent(zero); +% CHECK(assert_equal(expected,actual2,1e-2)); diff --git a/matlab/testGaussianFactorGraph.m b/matlab/testGaussianFactorGraph.m index 7bfe9e1d6..b77a3fd1e 100644 --- a/matlab/testGaussianFactorGraph.m +++ b/matlab/testGaussianFactorGraph.m @@ -6,8 +6,8 @@ CHECK('equals',fg.equals(fg2,1e-9)); %----------------------------------------------------------------------- % error -cfg = createZeroDelta(); -actual = fg.error(cfg); +zero = createZeroDelta(); +actual = fg.error(zero); DOUBLES_EQUAL( 5.625, actual, 1e-9 ); %----------------------------------------------------------------------- @@ -60,25 +60,9 @@ CHECK('eliminateAll', actual1.equals(expected,1e-5)); fg = createGaussianFactorGraph(); ord = Ordering; +ord.push_back('x1'); ord.push_back('x2'); ord.push_back('l1'); -ord.push_back('x1'); -A = fg.matrix(ord); +[H,z] = fg.matrix(ord); -%----------------------------------------------------------------------- -% gradientDescent - -% Expected solution -fg = createGaussianFactorGraph(); -expected = fg.optimize_(ord); % destructive - -% Do gradient descent -% fg2 = createGaussianFactorGraph(); -% zero = createZeroDelta(); -% actual = fg2.gradientDescent(zero); -% CHECK(assert_equal(expected,actual,1e-2)); - -% Do conjugate gradient descent -% actual2 = fg2.conjugateGradientDescent(zero); -% CHECK(assert_equal(expected,actual2,1e-2));