diff --git a/cpp/Matrix.cpp b/cpp/Matrix.cpp index 897079267..694ecd804 100644 --- a/cpp/Matrix.cpp +++ b/cpp/Matrix.cpp @@ -133,20 +133,26 @@ bool assert_equal(const Matrix& expected, const Matrix& actual, double tol) { return false; } -/* ************************************************************************ */ -/** negation */ -/* ************************************************************************ */ -/* -Matrix operator-() const -{ - size_t m = size1(),n=size2(); - Matrix M(m,n); - for(size_t i = 0; i < m; i++) - for(size_t j = 0; j < n; j++) - M(i,j) = -matrix_(i,j); - return M; +/* ************************************************************************* */ +Vector operator^(const Matrix& A, const Vector & v) { + if (A.size1()!=v.size()) throw std::invalid_argument( + boost::str(boost::format("Matrix operator^ : A.m(%d)!=v.size(%d)") % A.size1() % v.size())); + Vector vt = trans(v); + Vector vtA = prod(vt,A); + return trans(vtA); +} + +/* ************************************************************************* */ +void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x) { + // ublas Xj += prod(trans(Aj),Ei) is terribly slow + // TODO: use BLAS + for (int j = 0; j < A.size2(); j++) { + double& Xj1 = x(j); + for (int i = 0; i < A.size1(); i++) { + Xj1 += A(i, j) * e(i); + } + } } -*/ /* ************************************************************************* */ Vector Vector_(const Matrix& A) diff --git a/cpp/Matrix.h b/cpp/Matrix.h index d8cd2942d..8ca8b3a65 100644 --- a/cpp/Matrix.h +++ b/cpp/Matrix.h @@ -84,42 +84,28 @@ bool assert_equal(const Matrix& A, const Matrix& B, double tol = 1e-9); /** * overload * for matrix-vector multiplication (as BOOST does not) */ -inline Vector operator*(const Matrix& A, const Vector & v) { - if (A.size2()!=v.size()) throw std::invalid_argument( - boost::str(boost::format("Matrix operator* : A.n(%d)!=v.size(%d)") % A.size2() % v.size())); - return prod(A,v); -} +inline Vector operator*(const Matrix& A, const Vector & v) { return prod(A,v);} /** * overload ^ for trans(A)*v * We transpose the vectors for speed. */ -inline Vector operator^(const Matrix& A, const Vector & v) { - if (A.size1()!=v.size()) throw std::invalid_argument( - boost::str(boost::format("Matrix operator^ : A.m(%d)!=v.size(%d)") % A.size1() % v.size())); - Vector vt = trans(v); - Vector vtA = prod(vt,A); - return trans(vtA); -} +Vector operator^(const Matrix& A, const Vector & v); + +/** + * BLAS Level-2 style x <- x + A'*e + */ +void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x); /** * overload * for vector*matrix multiplication (as BOOST does not) */ -inline Vector operator*(const Vector & v, const Matrix& A) { - if (A.size1()!=v.size()) throw std::invalid_argument( - boost::str(boost::format("Matrix operator* : A.m(%d)!=v.size(%d)") % A.size1() % v.size())); - return prod(v,A); -} +inline Vector operator*(const Vector & v, const Matrix& A) { return prod(v,A);} /** * overload * for matrix multiplication (as BOOST does not) */ -inline Matrix operator*(const Matrix& A, const Matrix& B) { - // richard: boost already does this check in debug mode I think -// if (A.size2()!=B.size1()) throw std::invalid_argument( -// boost::str(boost::format("Matrix operator* : A.n(%d)!=B.m(%d)") % A.size2() % B.size1())); - return prod(A,B); -} +inline Matrix operator*(const Matrix& A, const Matrix& B) { return prod(A,B);} /** * convert to column vector, column order !!! diff --git a/cpp/testMatrix.cpp b/cpp/testMatrix.cpp index 2c374b8b2..fb5cdbf8f 100644 --- a/cpp/testMatrix.cpp +++ b/cpp/testMatrix.cpp @@ -772,6 +772,21 @@ TEST( matrix, square_root_positive ) CHECK(assert_equal(cov, prod(trans(actual),actual))); } +/* ************************************************************************* */ +TEST( matrix, transposeMultiplyAdd ) +{ + Matrix A = Matrix_(3,4, + 4., 0., 0., 1., + 0., 4., 0., 2., + 0., 0., 1., 3. + ); + Vector x = Vector_(4, 1., 2., 3., 4.), e = Vector_(3, 5., 6., 7.), + expected = x + prod(trans(A), e); + + transposeMultiplyAdd(A,e,x); + CHECK(assert_equal(expected, x)); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } /* ************************************************************************* */