From 7c3131b53335a399b409579e227a4f3563e7ba06 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 27 Oct 2013 14:58:51 +0000 Subject: [PATCH] Made multiplyHessian into multiplyHessianAdd --- gtsam/linear/GaussianFactorGraph.cpp | 7 +++---- gtsam/linear/GaussianFactorGraph.h | 5 +++-- .../testGaussianFactorGraphUnordered.cpp | 20 +++++++++++++------ 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/gtsam/linear/GaussianFactorGraph.cpp b/gtsam/linear/GaussianFactorGraph.cpp index a7d7de0be..0cc3556db 100644 --- a/gtsam/linear/GaussianFactorGraph.cpp +++ b/gtsam/linear/GaussianFactorGraph.cpp @@ -247,11 +247,10 @@ namespace gtsam { } /* ************************************************************************* */ - VectorValues GaussianFactorGraph::multiplyHessian(const VectorValues& x) const { - VectorValues y; + void GaussianFactorGraph::multiplyHessianAdd(double alpha, + const VectorValues& x, VectorValues& y) const { BOOST_FOREACH(const GaussianFactor::shared_ptr& f, *this) - f->multiplyHessianAdd(1.0,x,y); - return y; + f->multiplyHessianAdd(alpha, x, y); } /* ************************************************************************* */ diff --git a/gtsam/linear/GaussianFactorGraph.h b/gtsam/linear/GaussianFactorGraph.h index 76f97607b..e52321cea 100644 --- a/gtsam/linear/GaussianFactorGraph.h +++ b/gtsam/linear/GaussianFactorGraph.h @@ -269,8 +269,9 @@ namespace gtsam { ///** return A*x */ Errors operator*(const VectorValues& x) const; - ///** return A'A*x */ - VectorValues multiplyHessian(const VectorValues& x) const; + ///** y += alpha*A'A*x */ + void multiplyHessianAdd(double alpha, const VectorValues& x, + VectorValues& y) const; ///** In-place version e <- A*x that overwrites e. */ void multiplyInPlace(const VectorValues& x, Errors& e) const; diff --git a/gtsam/linear/tests/testGaussianFactorGraphUnordered.cpp b/gtsam/linear/tests/testGaussianFactorGraphUnordered.cpp index 1ba31dec1..8b1874552 100644 --- a/gtsam/linear/tests/testGaussianFactorGraphUnordered.cpp +++ b/gtsam/linear/tests/testGaussianFactorGraphUnordered.cpp @@ -142,8 +142,6 @@ TEST(GaussianFactorGraph, matrices) { } /* ************************************************************************* */ -static Key X1=2,X2=0,L1=1; - static GaussianFactorGraph createSimpleGaussianFactorGraph() { GaussianFactorGraph fg; SharedDiagonal unit2 = noiseModel::Unit::Create(2); @@ -224,7 +222,7 @@ TEST(GaussianFactorGraph, eliminate_empty ) } /* ************************************************************************* */ -TEST( GaussianFactorGraph, multiplyHessian ) +TEST( GaussianFactorGraph, multiplyHessianAdd ) { GaussianFactorGraph A = createSimpleGaussianFactorGraph(); @@ -238,8 +236,13 @@ TEST( GaussianFactorGraph, multiplyHessian ) expected.insert(1, (Vec(2) << 0, 0)); expected.insert(2, (Vec(2) << 950, 1050)); - VectorValues actual = A.multiplyHessian(x); + VectorValues actual; + A.multiplyHessianAdd(1.0, x, actual); EXPECT(assert_equal(expected, actual)); + + // now, do it with non-zero y + A.multiplyHessianAdd(1.0, x, actual); + EXPECT(assert_equal(2*expected, actual)); } /* ************************************************************************* */ @@ -251,7 +254,7 @@ static GaussianFactorGraph createGaussianFactorGraphWithHessianFactor() { } /* ************************************************************************* */ -TEST( GaussianFactorGraph, multiplyHessian2 ) +TEST( GaussianFactorGraph, multiplyHessianAdd2 ) { GaussianFactorGraph A = createGaussianFactorGraphWithHessianFactor(); @@ -266,8 +269,13 @@ TEST( GaussianFactorGraph, multiplyHessian2 ) expected.insert(1, (Vec(2) << 2900, 2900)); expected.insert(2, (Vec(2) << 6750, 6850)); - VectorValues actual = A.multiplyHessian(x); + VectorValues actual; + A.multiplyHessianAdd(1.0, x, actual); EXPECT(assert_equal(expected, actual)); + + // now, do it with non-zero y + A.multiplyHessianAdd(1.0, x, actual); + EXPECT(assert_equal(2*expected, actual)); }