diff --git a/cpp/Matrix.h b/cpp/Matrix.h index f296102d6..c2c95d4ab 100644 --- a/cpp/Matrix.h +++ b/cpp/Matrix.h @@ -85,7 +85,18 @@ bool assert_equal(const Matrix& A, const Matrix& B, double tol = 1e-9); */ inline Vector operator*(const Matrix& A, const Vector & v) { if (A.size2()!=v.size()) throw(std::invalid_argument("Matrix operator* : A.n!=v.size")); - return Vector(prod(A,v)); + return prod(A,v); +} + +/** + * overload ^ for trans(A)*v + * We transpose the vectors for speed. + */ +inline Vector operator^(const Matrix& A, const Vector & v) { + if (A.size1()!=v.size()) throw(std::invalid_argument("Matrix operator^ : A.m!=v.size")); + Vector vt = trans(v); + Vector vtA = prod(vt,A); + return trans(vtA); } /** @@ -93,7 +104,7 @@ inline Vector operator*(const Matrix& A, const Vector & v) { */ inline Vector operator*(const Vector & v, const Matrix& A) { if (A.size1()!=v.size()) throw(std::invalid_argument("Matrix operator* : A.m!=v.size")); - return Vector(prod(v,A)); + return prod(v,A); } /** diff --git a/cpp/testMatrix.cpp b/cpp/testMatrix.cpp index b3926fbfa..c18d1f0f6 100644 --- a/cpp/testMatrix.cpp +++ b/cpp/testMatrix.cpp @@ -337,16 +337,12 @@ TEST( matrix, matrix_vector_multiplication ) 1.0,2.0,3.0, 4.0,5.0,6.0 ); - Vector v(3); - v(0) = 1.0; - v(1) = 2.0; - v(2) = 3.0; - - Vector Av(2); - Av(0) = 14.0; - Av(1) = 32.0; + Vector v = Vector_(3,1.,2.,3.); + Vector Av = Vector_(2,14.,32.); + Vector AtAv = Vector_(3,142.,188.,234.); EQUALITY(A*v,Av); + EQUALITY(A^Av,AtAv); } /* ************************************************************************* */