Matrix::transposeMultiplyAdd

release/4.3a0
Frank Dellaert 2010-01-31 02:53:03 +00:00
parent 9b4ff5e099
commit 4f998e5ecd
3 changed files with 43 additions and 36 deletions

View File

@ -133,20 +133,26 @@ bool assert_equal(const Matrix& expected, const Matrix& actual, double tol) {
return false; return false;
} }
/* ************************************************************************ */ /* ************************************************************************* */
/** negation */ Vector operator^(const Matrix& A, const Vector & v) {
/* ************************************************************************ */ if (A.size1()!=v.size()) throw std::invalid_argument(
/* boost::str(boost::format("Matrix operator^ : A.m(%d)!=v.size(%d)") % A.size1() % v.size()));
Matrix operator-() const Vector vt = trans(v);
{ Vector vtA = prod(vt,A);
size_t m = size1(),n=size2(); return trans(vtA);
Matrix M(m,n); }
for(size_t i = 0; i < m; i++)
for(size_t j = 0; j < n; j++) /* ************************************************************************* */
M(i,j) = -matrix_(i,j); void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x) {
return M; // ublas Xj += prod(trans(Aj),Ei) is terribly slow
// TODO: use BLAS
for (int j = 0; j < A.size2(); j++) {
double& Xj1 = x(j);
for (int i = 0; i < A.size1(); i++) {
Xj1 += A(i, j) * e(i);
}
}
} }
*/
/* ************************************************************************* */ /* ************************************************************************* */
Vector Vector_(const Matrix& A) Vector Vector_(const Matrix& A)

View File

@ -84,42 +84,28 @@ bool assert_equal(const Matrix& A, const Matrix& B, double tol = 1e-9);
/** /**
* overload * for matrix-vector multiplication (as BOOST does not) * overload * for matrix-vector multiplication (as BOOST does not)
*/ */
inline Vector operator*(const Matrix& A, const Vector & v) { inline Vector operator*(const Matrix& A, const Vector & v) { return prod(A,v);}
if (A.size2()!=v.size()) throw std::invalid_argument(
boost::str(boost::format("Matrix operator* : A.n(%d)!=v.size(%d)") % A.size2() % v.size()));
return prod(A,v);
}
/** /**
* overload ^ for trans(A)*v * overload ^ for trans(A)*v
* We transpose the vectors for speed. * We transpose the vectors for speed.
*/ */
inline Vector operator^(const Matrix& A, const Vector & v) { Vector operator^(const Matrix& A, const Vector & v);
if (A.size1()!=v.size()) throw std::invalid_argument(
boost::str(boost::format("Matrix operator^ : A.m(%d)!=v.size(%d)") % A.size1() % v.size())); /**
Vector vt = trans(v); * BLAS Level-2 style x <- x + A'*e
Vector vtA = prod(vt,A); */
return trans(vtA); void transposeMultiplyAdd(const Matrix& A, const Vector& e, Vector& x);
}
/** /**
* overload * for vector*matrix multiplication (as BOOST does not) * overload * for vector*matrix multiplication (as BOOST does not)
*/ */
inline Vector operator*(const Vector & v, const Matrix& A) { inline Vector operator*(const Vector & v, const Matrix& A) { return prod(v,A);}
if (A.size1()!=v.size()) throw std::invalid_argument(
boost::str(boost::format("Matrix operator* : A.m(%d)!=v.size(%d)") % A.size1() % v.size()));
return prod(v,A);
}
/** /**
* overload * for matrix multiplication (as BOOST does not) * overload * for matrix multiplication (as BOOST does not)
*/ */
inline Matrix operator*(const Matrix& A, const Matrix& B) { inline Matrix operator*(const Matrix& A, const Matrix& B) { return prod(A,B);}
// richard: boost already does this check in debug mode I think
// if (A.size2()!=B.size1()) throw std::invalid_argument(
// boost::str(boost::format("Matrix operator* : A.n(%d)!=B.m(%d)") % A.size2() % B.size1()));
return prod(A,B);
}
/** /**
* convert to column vector, column order !!! * convert to column vector, column order !!!

View File

@ -772,6 +772,21 @@ TEST( matrix, square_root_positive )
CHECK(assert_equal(cov, prod(trans(actual),actual))); CHECK(assert_equal(cov, prod(trans(actual),actual)));
} }
/* ************************************************************************* */
TEST( matrix, transposeMultiplyAdd )
{
Matrix A = Matrix_(3,4,
4., 0., 0., 1.,
0., 4., 0., 2.,
0., 0., 1., 3.
);
Vector x = Vector_(4, 1., 2., 3., 4.), e = Vector_(3, 5., 6., 7.),
expected = x + prod(trans(A), e);
transposeMultiplyAdd(A,e,x);
CHECK(assert_equal(expected, x));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); } int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
/* ************************************************************************* */ /* ************************************************************************* */