diff --git a/cpp/GaussianBayesNet.cpp b/cpp/GaussianBayesNet.cpp index 6f3f550e1..91d488894 100644 --- a/cpp/GaussianBayesNet.cpp +++ b/cpp/GaussianBayesNet.cpp @@ -99,7 +99,7 @@ void backSubstituteInPlace(const GaussianBayesNet& bn, VectorConfig& y) { const Symbol& j = it->first; const Matrix& Rij = it->second; Vector& xj = x.getReference(j); - axpy(-1.0, Rij*xj, zi); // TODO: use BLAS level 2 + multiplyAdd(-1.0,Rij,xj,zi); } Vector& xi = x.getReference(i); xi = gtsam::backSubstituteUpper(cg->get_R(), zi); @@ -132,8 +132,7 @@ VectorConfig backSubstituteTranspose(const GaussianBayesNet& bn, const Symbol& i = it->first; const Matrix& Rij = it->second; Vector& gyi = gy.getReference(i); // should never fail - Matrix Lji = trans(Rij); // TODO avoid transpose of matrix ? - gyi -= Lji * gyj; + transposeMultiplyAdd(-1.0,Rij,gyj,gyi); } } diff --git a/cpp/GaussianConditional.cpp b/cpp/GaussianConditional.cpp index 18cac8c75..f838b310b 100644 --- a/cpp/GaussianConditional.cpp +++ b/cpp/GaussianConditional.cpp @@ -98,7 +98,7 @@ Vector GaussianConditional::solve(const VectorConfig& x) const { for (Parents::const_iterator it = parents_.begin(); it!= parents_.end(); it++) { const Symbol& j = it->first; const Matrix& Aj = it->second; - axpy(-1, Aj * x[j], rhs); // TODO use BLAS level 2 + multiplyAdd(-1.0,Aj,x[j],rhs); } return backSubstituteUpper(R_, rhs, false); } diff --git a/cpp/GaussianFactor.cpp b/cpp/GaussianFactor.cpp index 0747c2d43..f37c555b1 100644 --- a/cpp/GaussianFactor.cpp +++ b/cpp/GaussianFactor.cpp @@ -218,7 +218,7 @@ void GaussianFactor::transposeMultiplyAdd(double alpha, const Vector& e, FOREACH_PAIR(j, Aj, As_) { Vector& Xj = x.getReference(*j); - gtsam::transposeMultiplyAdd(*Aj, E, Xj); + gtsam::transposeMultiplyAdd(1.0, *Aj, E, Xj); } } diff --git a/cpp/Matrix.cpp b/cpp/Matrix.cpp index 694ecd804..3906d018f 100644 --- a/cpp/Matrix.cpp +++ b/cpp/Matrix.cpp @@ -133,6 +133,18 @@ bool assert_equal(const Matrix& expected, const Matrix& actual, double tol) { return false; } +/* ************************************************************************* */ +void multiplyAdd(double alpha, const Matrix& A, const Vector& x, Vector& e) { + // ublas e += prod(A,x) is terribly slow + // TODO: use BLAS + for (int i = 0; i < A.size1(); i++) { + double& ei = e(i); + for (int j = 0; j < A.size2(); j++) { + ei += alpha * A(i, j) * x(j); + } + } +} + /* ************************************************************************* */ Vector operator^(const Matrix& A, const Vector & v) { if (A.size1()!=v.size()) throw std::invalid_argument( @@ -143,13 +155,13 @@ Vector operator^(const Matrix& A, const Vector & v) { } /* ************************************************************************* */ -void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x) { - // ublas Xj += prod(trans(Aj),Ei) is terribly slow +void transposeMultiplyAdd(double alpha, const Matrix& A, const Vector& e, Vector& x) { + // ublas x += prod(trans(A),e) is terribly slow // TODO: use BLAS for (int j = 0; j < A.size2(); j++) { - double& Xj1 = x(j); + double& xj = x(j); for (int i = 0; i < A.size1(); i++) { - Xj1 += A(i, j) * e(i); + xj += alpha * A(i, j) * e(i); } } } diff --git a/cpp/Matrix.h b/cpp/Matrix.h index 8ca8b3a65..030dc6b5b 100644 --- a/cpp/Matrix.h +++ b/cpp/Matrix.h @@ -86,6 +86,11 @@ bool assert_equal(const Matrix& A, const Matrix& B, double tol = 1e-9); */ inline Vector operator*(const Matrix& A, const Vector & v) { return prod(A,v);} +/** + * BLAS Level-2 style e <- e + alpha*A*x + */ +void multiplyAdd(double alpha, const Matrix& A, const Vector& x, Vector& e); + /** * overload ^ for trans(A)*v * We transpose the vectors for speed. @@ -93,9 +98,9 @@ inline Vector operator*(const Matrix& A, const Vector & v) { return prod(A,v);} Vector operator^(const Matrix& A, const Vector & v); /** - * BLAS Level-2 style x <- x + A'*e + * BLAS Level-2 style x <- x + alpha*A'*e */ -void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x); +void transposeMultiplyAdd(double alpha, const Matrix& A, const Vector& e, Vector& x); /** * overload * for vector*matrix multiplication (as BOOST does not) diff --git a/cpp/iterative.h b/cpp/iterative.h index 07394563d..1db7c0dbb 100644 --- a/cpp/iterative.h +++ b/cpp/iterative.h @@ -63,7 +63,7 @@ namespace gtsam { /** x += alpha* A'*e */ inline void transposeMultiplyAdd(double alpha, const Vector& e, Vector& x) const { - gtsam::transposeMultiplyAdd(A_,alpha*e,x); + gtsam::transposeMultiplyAdd(alpha,A_,e,x); } /** diff --git a/cpp/testMatrix.cpp b/cpp/testMatrix.cpp index fb5cdbf8f..011970682 100644 --- a/cpp/testMatrix.cpp +++ b/cpp/testMatrix.cpp @@ -772,6 +772,21 @@ TEST( matrix, square_root_positive ) CHECK(assert_equal(cov, prod(trans(actual),actual))); } +/* ************************************************************************* */ +TEST( matrix, multiplyAdd ) +{ + Matrix A = Matrix_(3,4, + 4., 0., 0., 1., + 0., 4., 0., 2., + 0., 0., 1., 3. + ); + Vector x = Vector_(4, 1., 2., 3., 4.), e = Vector_(3, 5., 6., 7.), + expected = e + prod(A, x); + + multiplyAdd(1,A,x,e); + CHECK(assert_equal(expected, e)); +} + /* ************************************************************************* */ TEST( matrix, transposeMultiplyAdd ) { @@ -783,7 +798,7 @@ TEST( matrix, transposeMultiplyAdd ) Vector x = Vector_(4, 1., 2., 3., 4.), e = Vector_(3, 5., 6., 7.), expected = x + prod(trans(A), e); - transposeMultiplyAdd(A,e,x); + transposeMultiplyAdd(1,A,e,x); CHECK(assert_equal(expected, x)); }