diff --git a/cpp/Matrix.cpp b/cpp/Matrix.cpp
index 562b8c385..528870324 100644
--- a/cpp/Matrix.cpp
+++ b/cpp/Matrix.cpp
@@ -138,14 +138,20 @@ bool assert_equal(const Matrix& expected, const Matrix& actual, double tol) {
 
 /* ************************************************************************* */
 void multiplyAdd(double alpha, const Matrix& A, const Vector& x, Vector& e) {
+#ifdef GSL
+  gsl_vector_const_view xg = gsl_vector_const_view_array(x.data().begin(), x.size());
+  gsl_vector_view eg = gsl_vector_view_array(e.data().begin(), e.size());
+  gsl_matrix_const_view Ag = gsl_matrix_const_view_array(A.data().begin(), A.size1(), A.size2());
+  gsl_blas_dgemv (CblasNoTrans, alpha, &(Ag.matrix), &(xg.vector), 1.0, &(eg.vector));
+#else
 	// ublas e += prod(A,x) is terribly slow
-	// TODO: use BLAS
 	for (int i = 0; i < A.size1(); i++) {
-		double& ei = e(i);
-		for (int j = 0; j < A.size2(); j++) {
-			ei += alpha * A(i, j) * x(j);
-		}
+    double& ei = e(i);
+    for (int j = 0; j < A.size2(); j++) {
+      ei += alpha * A(i, j) * x(j);
 	}
+  }
+#endif
 }
 
 /* ************************************************************************* */
@@ -159,6 +165,12 @@ Vector operator^(const Matrix& A, const Vector & v) {
 
 /* ************************************************************************* */
 void transposeMultiplyAdd(double alpha, const Matrix& A, const Vector& e, Vector& x) {
+#ifdef GSL
+  gsl_vector_const_view eg = gsl_vector_const_view_array(e.data().begin(), e.size());
+  gsl_vector_view xg = gsl_vector_view_array(x.data().begin(), x.size());
+  gsl_matrix_const_view Ag = gsl_matrix_const_view_array(A.data().begin(), A.size1(), A.size2());
+  gsl_blas_dgemv (CblasTrans, alpha, &(Ag.matrix), &(eg.vector), 1.0, &(xg.vector));
+#else
 	// ublas x += prod(trans(A),e) is terribly slow
 	// TODO: use BLAS
 	for (int j = 0; j < A.size2(); j++) {
@@ -167,6 +179,7 @@ void transposeMultiplyAdd(double alpha, const Matrix& A, const Vector& e, Vector
 			xj += alpha * A(i, j) * e(i);
 		}
 	}
+#endif
 }
 
 /* ************************************************************************* */