Matrix::transposeMultiplyAdd
parent
9b4ff5e099
commit
4f998e5ecd
|
|
@ -133,20 +133,26 @@ bool assert_equal(const Matrix& expected, const Matrix& actual, double tol) {
|
|||
return false;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
/** negation */
|
||||
/* ************************************************************************ */
|
||||
/*
|
||||
Matrix operator-() const
|
||||
{
|
||||
size_t m = size1(),n=size2();
|
||||
Matrix M(m,n);
|
||||
for(size_t i = 0; i < m; i++)
|
||||
for(size_t j = 0; j < n; j++)
|
||||
M(i,j) = -matrix_(i,j);
|
||||
return M;
|
||||
/* ************************************************************************* */
|
||||
Vector operator^(const Matrix& A, const Vector & v) {
|
||||
if (A.size1()!=v.size()) throw std::invalid_argument(
|
||||
boost::str(boost::format("Matrix operator^ : A.m(%d)!=v.size(%d)") % A.size1() % v.size()));
|
||||
Vector vt = trans(v);
|
||||
Vector vtA = prod(vt,A);
|
||||
return trans(vtA);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x) {
|
||||
// ublas Xj += prod(trans(Aj),Ei) is terribly slow
|
||||
// TODO: use BLAS
|
||||
for (int j = 0; j < A.size2(); j++) {
|
||||
double& Xj1 = x(j);
|
||||
for (int i = 0; i < A.size1(); i++) {
|
||||
Xj1 += A(i, j) * e(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
/* ************************************************************************* */
|
||||
Vector Vector_(const Matrix& A)
|
||||
|
|
|
|||
32
cpp/Matrix.h
32
cpp/Matrix.h
|
|
@ -84,42 +84,28 @@ bool assert_equal(const Matrix& A, const Matrix& B, double tol = 1e-9);
|
|||
/**
|
||||
* overload * for matrix-vector multiplication (as BOOST does not)
|
||||
*/
|
||||
inline Vector operator*(const Matrix& A, const Vector & v) {
|
||||
if (A.size2()!=v.size()) throw std::invalid_argument(
|
||||
boost::str(boost::format("Matrix operator* : A.n(%d)!=v.size(%d)") % A.size2() % v.size()));
|
||||
return prod(A,v);
|
||||
}
|
||||
inline Vector operator*(const Matrix& A, const Vector & v) { return prod(A,v);}
|
||||
|
||||
/**
|
||||
* overload ^ for trans(A)*v
|
||||
* We transpose the vectors for speed.
|
||||
*/
|
||||
inline Vector operator^(const Matrix& A, const Vector & v) {
|
||||
if (A.size1()!=v.size()) throw std::invalid_argument(
|
||||
boost::str(boost::format("Matrix operator^ : A.m(%d)!=v.size(%d)") % A.size1() % v.size()));
|
||||
Vector vt = trans(v);
|
||||
Vector vtA = prod(vt,A);
|
||||
return trans(vtA);
|
||||
}
|
||||
Vector operator^(const Matrix& A, const Vector & v);
|
||||
|
||||
/**
|
||||
* BLAS Level-2 style x <- x + A'*e
|
||||
*/
|
||||
void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x);
|
||||
|
||||
/**
|
||||
* overload * for vector*matrix multiplication (as BOOST does not)
|
||||
*/
|
||||
inline Vector operator*(const Vector & v, const Matrix& A) {
|
||||
if (A.size1()!=v.size()) throw std::invalid_argument(
|
||||
boost::str(boost::format("Matrix operator* : A.m(%d)!=v.size(%d)") % A.size1() % v.size()));
|
||||
return prod(v,A);
|
||||
}
|
||||
inline Vector operator*(const Vector & v, const Matrix& A) { return prod(v,A);}
|
||||
|
||||
/**
|
||||
* overload * for matrix multiplication (as BOOST does not)
|
||||
*/
|
||||
inline Matrix operator*(const Matrix& A, const Matrix& B) {
|
||||
// richard: boost already does this check in debug mode I think
|
||||
// if (A.size2()!=B.size1()) throw std::invalid_argument(
|
||||
// boost::str(boost::format("Matrix operator* : A.n(%d)!=B.m(%d)") % A.size2() % B.size1()));
|
||||
return prod(A,B);
|
||||
}
|
||||
inline Matrix operator*(const Matrix& A, const Matrix& B) { return prod(A,B);}
|
||||
|
||||
/**
|
||||
* convert to column vector, column order !!!
|
||||
|
|
|
|||
|
|
@ -772,6 +772,21 @@ TEST( matrix, square_root_positive )
|
|||
CHECK(assert_equal(cov, prod(trans(actual),actual)));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( matrix, transposeMultiplyAdd )
|
||||
{
|
||||
Matrix A = Matrix_(3,4,
|
||||
4., 0., 0., 1.,
|
||||
0., 4., 0., 2.,
|
||||
0., 0., 1., 3.
|
||||
);
|
||||
Vector x = Vector_(4, 1., 2., 3., 4.), e = Vector_(3, 5., 6., 7.),
|
||||
expected = x + prod(trans(A), e);
|
||||
|
||||
transposeMultiplyAdd(A,e,x);
|
||||
CHECK(assert_equal(expected, x));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
Loading…
Reference in New Issue