From 5e4b23df599527e21e6f1c127c4d39eabcbb426d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 31 Jan 2010 16:04:24 +0000 Subject: [PATCH] Matrix::multiplyAdd and transposeMultiplyAdd are "level 2" BLAS and speed up the numeric part of the code substantially. Alex might be able to speed them up even more by making them use real BLAS code within Matrix.cpp. --- cpp/GaussianBayesNet.cpp | 5 ++--- cpp/GaussianConditional.cpp | 2 +- cpp/GaussianFactor.cpp | 2 +- cpp/Matrix.cpp | 20 ++++++++++++++++---- cpp/Matrix.h | 9 +++++++-- cpp/iterative.h | 2 +- cpp/testMatrix.cpp | 17 ++++++++++++++++- 7 files changed, 44 insertions(+), 13 deletions(-) diff --git a/cpp/GaussianBayesNet.cpp b/cpp/GaussianBayesNet.cpp index 6f3f550e1..91d488894 100644 --- a/cpp/GaussianBayesNet.cpp +++ b/cpp/GaussianBayesNet.cpp @@ -99,7 +99,7 @@ void backSubstituteInPlace(const GaussianBayesNet& bn, VectorConfig& y) { const Symbol& j = it->first; const Matrix& Rij = it->second; Vector& xj = x.getReference(j); - axpy(-1.0, Rij*xj, zi); // TODO: use BLAS level 2 + multiplyAdd(-1.0,Rij,xj,zi); } Vector& xi = x.getReference(i); xi = gtsam::backSubstituteUpper(cg->get_R(), zi); @@ -132,8 +132,7 @@ VectorConfig backSubstituteTranspose(const GaussianBayesNet& bn, const Symbol& i = it->first; const Matrix& Rij = it->second; Vector& gyi = gy.getReference(i); // should never fail - Matrix Lji = trans(Rij); // TODO avoid transpose of matrix ? - gyi -= Lji * gyj; + transposeMultiplyAdd(-1.0,Rij,gyj,gyi); } } diff --git a/cpp/GaussianConditional.cpp b/cpp/GaussianConditional.cpp index 18cac8c75..f838b310b 100644 --- a/cpp/GaussianConditional.cpp +++ b/cpp/GaussianConditional.cpp @@ -98,7 +98,7 @@ Vector GaussianConditional::solve(const VectorConfig& x) const { for (Parents::const_iterator it = parents_.begin(); it!= parents_.end(); it++) { const Symbol& j = it->first; const Matrix& Aj = it->second; - axpy(-1, Aj * x[j], rhs); // TODO use BLAS level 2 + multiplyAdd(-1.0,Aj,x[j],rhs); } return backSubstituteUpper(R_, rhs, false); } diff --git a/cpp/GaussianFactor.cpp b/cpp/GaussianFactor.cpp index 0747c2d43..f37c555b1 100644 --- a/cpp/GaussianFactor.cpp +++ b/cpp/GaussianFactor.cpp @@ -218,7 +218,7 @@ void GaussianFactor::transposeMultiplyAdd(double alpha, const Vector& e, FOREACH_PAIR(j, Aj, As_) { Vector& Xj = x.getReference(*j); - gtsam::transposeMultiplyAdd(*Aj, E, Xj); + gtsam::transposeMultiplyAdd(1.0, *Aj, E, Xj); } } diff --git a/cpp/Matrix.cpp b/cpp/Matrix.cpp index 694ecd804..3906d018f 100644 --- a/cpp/Matrix.cpp +++ b/cpp/Matrix.cpp @@ -133,6 +133,18 @@ bool assert_equal(const Matrix& expected, const Matrix& actual, double tol) { return false; } +/* ************************************************************************* */ +void multiplyAdd(double alpha, const Matrix& A, const Vector& x, Vector& e) { + // ublas e += prod(A,x) is terribly slow + // TODO: use BLAS + for (int i = 0; i < A.size1(); i++) { + double& ei = e(i); + for (int j = 0; j < A.size2(); j++) { + ei += alpha * A(i, j) * x(j); + } + } +} + /* ************************************************************************* */ Vector operator^(const Matrix& A, const Vector & v) { if (A.size1()!=v.size()) throw std::invalid_argument( @@ -143,13 +155,13 @@ Vector operator^(const Matrix& A, const Vector & v) { } /* ************************************************************************* */ -void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x) { - // ublas Xj += prod(trans(Aj),Ei) is terribly slow +void transposeMultiplyAdd(double alpha, const Matrix& A, const Vector& e, Vector& x) { + // ublas x += prod(trans(A),e) is terribly slow // TODO: use BLAS for (int j = 0; j < A.size2(); j++) { - double& Xj1 = x(j); + double& xj = x(j); for (int i = 0; i < A.size1(); i++) { - Xj1 += A(i, j) * e(i); + xj += alpha * A(i, j) * e(i); } } } diff --git a/cpp/Matrix.h b/cpp/Matrix.h index 8ca8b3a65..030dc6b5b 100644 --- a/cpp/Matrix.h +++ b/cpp/Matrix.h @@ -86,6 +86,11 @@ bool assert_equal(const Matrix& A, const Matrix& B, double tol = 1e-9); */ inline Vector operator*(const Matrix& A, const Vector & v) { return prod(A,v);} +/** + * BLAS Level-2 style e <- e + alpha*A*x + */ +void multiplyAdd(double alpha, const Matrix& A, const Vector& x, Vector& e); + /** * overload ^ for trans(A)*v * We transpose the vectors for speed. @@ -93,9 +98,9 @@ inline Vector operator*(const Matrix& A, const Vector & v) { return prod(A,v);} Vector operator^(const Matrix& A, const Vector & v); /** - * BLAS Level-2 style x <- x + A'*e + * BLAS Level-2 style x <- x + alpha*A'*e */ -void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x); +void transposeMultiplyAdd(double alpha, const Matrix& A, const Vector& e, Vector& x); /** * overload * for vector*matrix multiplication (as BOOST does not) diff --git a/cpp/iterative.h b/cpp/iterative.h index 07394563d..1db7c0dbb 100644 --- a/cpp/iterative.h +++ b/cpp/iterative.h @@ -63,7 +63,7 @@ namespace gtsam { /** x += alpha* A'*e */ inline void transposeMultiplyAdd(double alpha, const Vector& e, Vector& x) const { - gtsam::transposeMultiplyAdd(A_,alpha*e,x); + gtsam::transposeMultiplyAdd(alpha,A_,e,x); } /** diff --git a/cpp/testMatrix.cpp b/cpp/testMatrix.cpp index fb5cdbf8f..011970682 100644 --- a/cpp/testMatrix.cpp +++ b/cpp/testMatrix.cpp @@ -772,6 +772,21 @@ TEST( matrix, square_root_positive ) CHECK(assert_equal(cov, prod(trans(actual),actual))); } +/* ************************************************************************* */ +TEST( matrix, multiplyAdd ) +{ + Matrix A = Matrix_(3,4, + 4., 0., 0., 1., + 0., 4., 0., 2., + 0., 0., 1., 3. + ); + Vector x = Vector_(4, 1., 2., 3., 4.), e = Vector_(3, 5., 6., 7.), + expected = e + prod(A, x); + + multiplyAdd(1,A,x,e); + CHECK(assert_equal(expected, e)); +} + /* ************************************************************************* */ TEST( matrix, transposeMultiplyAdd ) { @@ -783,7 +798,7 @@ TEST( matrix, transposeMultiplyAdd ) Vector x = Vector_(4, 1., 2., 3., 4.), e = Vector_(3, 5., 6., 7.), expected = x + prod(trans(A), e); - transposeMultiplyAdd(A,e,x); + transposeMultiplyAdd(1,A,e,x); CHECK(assert_equal(expected, x)); }