diff --git a/cpp/Matrix.cpp b/cpp/Matrix.cpp index c7cfdf02f..765b68586 100644 --- a/cpp/Matrix.cpp +++ b/cpp/Matrix.cpp @@ -181,17 +181,20 @@ void multiplyAdd(double alpha, const Matrix& A, const Vector& x, Vector& e) { /* ************************************************************************* */ void multiplyAdd(const Matrix& A, const Vector& x, Vector& e) { - multiplyAdd(1.0, A, x, e); // ublas e += prod(A,x) is terribly slow -// size_t m = A.size1(), n = A.size2(); -// double * ei = e.data().begin(); -// const double * aij = A.data().begin(); -// for (int i = 0; i < m; i++, ei++) { -// const double * xj = x.data().begin(); -// for (int j = 0; j < n; j++, aij++, xj++) -// (*ei) += (*aij) * (*xj); -// } -} +#ifdef CBLAS + multiplyAdd(1.0, A, x, e); +#else + size_t m = A.size1(), n = A.size2(); + double * ei = e.data().begin(); + const double * aij = A.data().begin(); + for (int i = 0; i < m; i++, ei++) { + const double * xj = x.data().begin(); + for (int j = 0; j < n; j++, aij++, xj++) + (*ei) += (*aij) * (*xj); + } +#endif + } /* ************************************************************************* */ Vector operator^(const Matrix& A, const Vector & v) { @@ -242,17 +245,19 @@ void transposeMultiplyAdd(double alpha, const Matrix& A, const Vector& e, Vector /* ************************************************************************* */ void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x) { - transposeMultiplyAdd(1.0, A, e, x); // ublas x += prod(trans(A),e) is terribly slow - // TODO: use BLAS -// size_t m = A.size1(), n = A.size2(); -// double * xj = x.data().begin(); -// for (int j = 0; j < n; j++,xj++) { -// const double * ei = e.data().begin(); -// const double * aij = A.data().begin() + j; -// for (int i = 0; i < m; i++, aij+=n, ei++) -// (*xj) += (*aij) * (*ei); -// } +#ifdef CBLAS + transposeMultiplyAdd(1.0, A, e, x); +#else + size_t m = A.size1(), n = A.size2(); + double * xj = x.data().begin(); + for (int j = 0; j < n; j++,xj++) { + const double * ei = e.data().begin(); + const double * aij = A.data().begin() + j; + for (int i = 0; i < m; i++, aij+=n, ei++) + (*xj) += (*aij) * (*ei); + } +#endif } /* ************************************************************************* */