Matrix::transposeMultiplyAdd

release/4.3a0
Frank Dellaert 2010-01-31 02:53:03 +00:00
parent 9b4ff5e099
commit 4f998e5ecd
3 changed files with 43 additions and 36 deletions

View File

@ -133,20 +133,26 @@ bool assert_equal(const Matrix& expected, const Matrix& actual, double tol) {
return false;
}
/* ************************************************************************ */
/** negation */
/* ************************************************************************ */
/*
Matrix operator-() const
{
size_t m = size1(),n=size2();
Matrix M(m,n);
for(size_t i = 0; i < m; i++)
for(size_t j = 0; j < n; j++)
M(i,j) = -matrix_(i,j);
return M;
/* ************************************************************************* */
Vector operator^(const Matrix& A, const Vector & v) {
if (A.size1()!=v.size()) throw std::invalid_argument(
boost::str(boost::format("Matrix operator^ : A.m(%d)!=v.size(%d)") % A.size1() % v.size()));
Vector vt = trans(v);
Vector vtA = prod(vt,A);
return trans(vtA);
}
/* ************************************************************************* */
void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x) {
// ublas Xj += prod(trans(Aj),Ei) is terribly slow
// TODO: use BLAS
for (int j = 0; j < A.size2(); j++) {
double& Xj1 = x(j);
for (int i = 0; i < A.size1(); i++) {
Xj1 += A(i, j) * e(i);
}
}
}
*/
/* ************************************************************************* */
Vector Vector_(const Matrix& A)

View File

@ -84,42 +84,28 @@ bool assert_equal(const Matrix& A, const Matrix& B, double tol = 1e-9);
/**
* overload * for matrix-vector multiplication (as BOOST does not)
*/
inline Vector operator*(const Matrix& A, const Vector & v) {
if (A.size2()!=v.size()) throw std::invalid_argument(
boost::str(boost::format("Matrix operator* : A.n(%d)!=v.size(%d)") % A.size2() % v.size()));
return prod(A,v);
}
inline Vector operator*(const Matrix& A, const Vector & v) { return prod(A,v);}
/**
* overload ^ for trans(A)*v
* We transpose the vectors for speed.
*/
inline Vector operator^(const Matrix& A, const Vector & v) {
if (A.size1()!=v.size()) throw std::invalid_argument(
boost::str(boost::format("Matrix operator^ : A.m(%d)!=v.size(%d)") % A.size1() % v.size()));
Vector vt = trans(v);
Vector vtA = prod(vt,A);
return trans(vtA);
}
Vector operator^(const Matrix& A, const Vector & v);
/**
* BLAS Level-2 style x <- x + A'*e
*/
void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x);
/**
* overload * for vector*matrix multiplication (as BOOST does not)
*/
inline Vector operator*(const Vector & v, const Matrix& A) {
if (A.size1()!=v.size()) throw std::invalid_argument(
boost::str(boost::format("Matrix operator* : A.m(%d)!=v.size(%d)") % A.size1() % v.size()));
return prod(v,A);
}
inline Vector operator*(const Vector & v, const Matrix& A) { return prod(v,A);}
/**
* overload * for matrix multiplication (as BOOST does not)
*/
inline Matrix operator*(const Matrix& A, const Matrix& B) {
// richard: boost already does this check in debug mode I think
// if (A.size2()!=B.size1()) throw std::invalid_argument(
// boost::str(boost::format("Matrix operator* : A.n(%d)!=B.m(%d)") % A.size2() % B.size1()));
return prod(A,B);
}
inline Matrix operator*(const Matrix& A, const Matrix& B) { return prod(A,B);}
/**
* convert to column vector, column order !!!

View File

@ -772,6 +772,21 @@ TEST( matrix, square_root_positive )
CHECK(assert_equal(cov, prod(trans(actual),actual)));
}
/* ************************************************************************* */
TEST( matrix, transposeMultiplyAdd )
{
Matrix A = Matrix_(3,4,
4., 0., 0., 1.,
0., 4., 0., 2.,
0., 0., 1., 3.
);
Vector x = Vector_(4, 1., 2., 3., 4.), e = Vector_(3, 5., 6., 7.),
expected = x + prod(trans(A), e);
transposeMultiplyAdd(A,e,x);
CHECK(assert_equal(expected, x));
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
/* ************************************************************************* */