From 4604cbce050d0f97a692427d181150ae1c2a5b58 Mon Sep 17 00:00:00 2001 From: Manohar Paluri Date: Sat, 27 Feb 2010 18:23:34 +0000 Subject: [PATCH] svd now handles m < n cases. Added unit tests to verify this. But svd in place will throw an exception for such cases. --- cpp/Matrix.cpp | 14 +++++-- cpp/Matrix.h | 17 +++++++-- cpp/testMatrix.cpp | 91 ++++++++++++++++++++++++++++++++++------------ 3 files changed, 92 insertions(+), 30 deletions(-) diff --git a/cpp/Matrix.cpp b/cpp/Matrix.cpp index cabcb8210..95dc1f528 100644 --- a/cpp/Matrix.cpp +++ b/cpp/Matrix.cpp @@ -966,6 +966,8 @@ Matrix cholesky_inverse(const Matrix &A) void svd(Matrix& A, Vector& s, Matrix& V, bool sort) { const size_t m=A.size1(), n=A.size2(); + if( m < n ) + throw invalid_argument("in-place svd calls NRC which needs matrix A with m>n"); double * q = new double[n]; // singular values @@ -984,13 +986,19 @@ void svd(Matrix& A, Vector& s, Matrix& V, bool sort) { delete[] v; delete[] q; //switched to array delete delete[] u; - } /* ************************************************************************* */ void svd(const Matrix& A, Matrix& U, Vector& s, Matrix& V, bool sort) { - U = A; // copy - svd(U,s,V,sort); // call in-place version + const size_t m=A.size1(), n=A.size2(); + if( m < n ) { + V = trans(A); + svd(V,s,U,sort); // A'=V*diag(s)*U' + } + else{ + U = A; // copy + svd(U,s,V,sort); // call in-place version + } } #if 0 diff --git a/cpp/Matrix.h b/cpp/Matrix.h index 3852898b1..7e8705901 100644 --- a/cpp/Matrix.h +++ b/cpp/Matrix.h @@ -326,14 +326,23 @@ inline Matrix skewSymmetric(const Vector& w) { return skewSymmetric(w(0),w(1),w( /** * SVD computes economy SVD A=U*S*V' * @param A an m*n matrix - * @param U output argument: m*n matrix - * @param S output argument: n-dim vector of singular values, sorted by default, pass false as last argument to avoid sorting!!! - * @param V output argument: n*n matrix + * @param U output argument: rotation matrix + * @param S output argument: vector of singular values, sorted by default, pass false as last argument to avoid sorting!!! + * @param V output argument: rotation matrix * @param sort boolean flag to sort singular values and V + * if m > n then U*S*V' = (m*n)*(n*n)*(n*n) (the m-n columns of U are of no use) + * if m < n then U*S*V' = (m*m)*(m*m)*(m*n) (the n-m columns of V are of no use) */ void svd(const Matrix& A, Matrix& U, Vector& S, Matrix& V, bool sort=true); -// in-place version +/* + * In place SVD, will throw an exception when m < n + * @param A an m*n matrix is modified to contain U + * @param S output argument: vector of singular values, sorted by default, pass false as last argument to avoid sorting!!! + * @param V output argument: rotation matrix + * @param sort boolean flag to sort singular values and V + * if m > n then U*S*V' = (m*n)*(n*n)*(n*n) (the m-n columns of U are of no use) + */ void svd(Matrix& A, Vector& S, Matrix& V, bool sort=true); /** Use SVD to calculate inverse square root of a matrix */ diff --git a/cpp/testMatrix.cpp b/cpp/testMatrix.cpp index 3f6a5801e..7902183cd 100644 --- a/cpp/testMatrix.cpp +++ b/cpp/testMatrix.cpp @@ -691,7 +691,7 @@ TEST( matrix, row_major_access ) } /* ************************************************************************* */ -TEST( matrix, svd ) +TEST( matrix, svd1 ) { double data[] = { 2, 1, 0 }; Vector v(3); @@ -707,34 +707,79 @@ TEST( matrix, svd ) } /* ************************************************************************* */ -TEST( matrix, svdordering ) -{ - /// Homography matrix for points - //Point2h(0, 0, 1), Point2h(4, 5, 1); - //Point2h(1, 0, 1), Point2h(5, 5, 1); - //Point2h(1, 1, 1), Point2h(5, 6, 1); - //Point2h(0, 1, 1), Point2h(4, 6, 1); - double data[] = {0,0,0,-4,-5,-1,0,0,0, - 4,5,1,0,0,0,0,0,0, - 0,0,0,0,0,0,0,0,0, - 0,0,0,-5,-5,-1,0,0,0, - 5,5,1,0,0,0,-5,-5,-1, - 0,0,0,5,5,1,0,0,0, - 0,0,0,-5,-6,-1,5,6,1, - 5,6,1,0,0,0,-5,-6,-1, - -5,-6,-1,5,6,1,0,0,0, - 0,0,0,-4,-6,-1,4,6,1, - 4,6,1,0,0,0,0,0,0, - -4,-6,-1,0,0,0,0,0,0}; +/// Sample A matrix for SVD +static double sampleData[] ={0,-2, 0,0, 3,0}; +static Matrix sampleA = Matrix_(3, 2, sampleData); +static Matrix sampleAt = trans(sampleA); - Matrix A = Matrix_(12, 9, data); +/* ************************************************************************* */ +TEST( matrix, svd2 ) +{ + Matrix U, V; + Vector s; + + Matrix expectedU = Matrix_(3, 2, 0.,1.,0.,0.,-1.,0.); + Vector expected_s = Vector_(2, 3.,2.); + Matrix expectedV = Matrix_(2, 2, -1.,0.,0.,-1.); + + svd(sampleA, U, s, V); + + EQUALITY(expectedU,U); + CHECK(equal_with_abs_tol(expected_s,s,1e-9)); + EQUALITY(expectedV,V); +} + +/* ************************************************************************* */ +TEST( matrix, svd3 ) +{ + Matrix U, V; + Vector s; + + Matrix expectedU = Matrix_(2, 2, -1.,0.,0.,-1.); + Vector expected_s = Vector_(2, 3.0,2.0); + Matrix expectedV = Matrix_(3, 2, 0.,1.,0.,0.,-1.,0.); + + svd(sampleAt, U, s, V); + Matrix S = diag(s); + Matrix t = prod(U,S); + Matrix Vt = trans(V); + + EQUALITY(sampleAt, prod(t,Vt)); + EQUALITY(expectedU,U); + CHECK(equal_with_abs_tol(expected_s,s,1e-9)); + EQUALITY(expectedV,V); +} + +/* ************************************************************************* */ +/// Homography matrix for points +//Point2h(0, 0, 1), Point2h(4, 5, 1); +//Point2h(1, 0, 1), Point2h(5, 5, 1); +//Point2h(1, 1, 1), Point2h(5, 6, 1); +//Point2h(0, 1, 1), Point2h(4, 6, 1); +static double homography_data[] = {0,0,0,-4,-5,-1,0,0,0, + 4,5,1,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0, + 0,0,0,-5,-5,-1,0,0,0, + 5,5,1,0,0,0,-5,-5,-1, + 0,0,0,5,5,1,0,0,0, + 0,0,0,-5,-6,-1,5,6,1, + 5,6,1,0,0,0,-5,-6,-1, + -5,-6,-1,5,6,1,0,0,0, + 0,0,0,-4,-6,-1,4,6,1, + 4,6,1,0,0,0,0,0,0, + -4,-6,-1,0,0,0,0,0,0}; +static Matrix homographyA = Matrix_(12, 9, homography_data); + +/* ************************************************************************* */ +TEST( matrix, svd_sort ) +{ Matrix U1, U2, V1, V2; Vector s1, s2; - svd(A, U1, s1, V1); + svd(homographyA, U1, s1, V1); for(int i = 0 ; i < 8 ; i++) CHECK(s1[i]>=s1[i+1]); // Check if singular values are sorted - svd(A, U2, s2, V2, false); + svd(homographyA, U2, s2, V2, false); CHECK(s1[8]==s2[7]); // Check if swapping is done CHECK(s1[7]==s2[8]); Vector v17 = column_(V1, 7);