From 16c55975c116312efa30f6b63021c54ad1ddfa2f Mon Sep 17 00:00:00 2001 From: justinca Date: Fri, 22 Jan 2010 22:28:03 +0000 Subject: [PATCH] Fix inverse_square_root, add cholesky decomposition options --- cpp/Matrix.cpp | 71 ++++++++++++++++++++++++++++++++++++++++++++++ cpp/Matrix.h | 8 ++++++ cpp/testMatrix.cpp | 44 ++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+) diff --git a/cpp/Matrix.cpp b/cpp/Matrix.cpp index 99b454512..897079267 100644 --- a/cpp/Matrix.cpp +++ b/cpp/Matrix.cpp @@ -18,6 +18,10 @@ #include #include #include +#include +#include +#include +#include #include #include "Matrix.h" @@ -714,6 +718,55 @@ double** createNRC(Matrix& A) { return a; } + + +/* ****************************************** + * + * Modified from Justin's codebase + * + * Idea came from other public domain code. Takes a S.P.D. matrix + * and computes the LL^t decomposition. returns L, which is lower + * triangular. Note this is the opposite convention from Matlab, + * which calculates Q'Q where Q is upper triangular. + * + * ******************************************/ + +namespace BNU = boost::numeric::ublas; + + +Matrix cholesky(const Matrix& A) +{ + assert(A.size1() == A.size2()); + Matrix L = zeros(A.size1(), A.size1()); + + for (size_t i = 0 ; i < A.size1(); i++) { + double p = A(i,i) - BNU::inner_prod( BNU::project( BNU::row(L, i), BNU::range(0, i) ), + BNU::project( BNU::row(L, i), BNU::range(0, i) ) ); + assert(p > 0); // Rank failure + double l_i_i = sqrt(p); + L(i,i) = l_i_i; + + BNU::matrix_column l_i(L, i); + project( l_i, BNU::range(i+1, A.size1()) ) + = ( BNU::project( BNU::column(A, i), BNU::range(i+1, A.size1()) ) + - BNU::prod( BNU::project(L, BNU::range(i+1, A.size1()), BNU::range(0, i)), + BNU::project(BNU::row(L, i), BNU::range(0, i) ) ) ) / l_i_i; + } + return L; +} + +/* + * This is not ultra efficient, but not terrible, either. + */ +Matrix cholesky_inverse(const Matrix &A) +{ + Matrix L = cholesky(A); + Matrix inv(boost::numeric::ublas::identity_matrix(A.size1())); + inplace_solve (L, inv, BNU::lower_tag ()); + return BNU::prod(trans(inv), inv); +} + + /* ************************************************************************* */ /** SVD */ /* ************************************************************************* */ @@ -749,6 +802,7 @@ void svd(const Matrix& A, Matrix& U, Vector& s, Matrix& V) { svd(U,s,V); // call in-place version } +#if 0 /* ************************************************************************* */ // TODO, would be faster with Cholesky Matrix inverse_square_root(const Matrix& A) { @@ -766,6 +820,23 @@ Matrix inverse_square_root(const Matrix& A) { for(size_t i = 0; i(A.size1())); + inplace_solve (L, inv, BNU::lower_tag ()); + return inv; +} + + /* ************************************************************************* */ Matrix square_root_positive(const Matrix& A) { diff --git a/cpp/Matrix.h b/cpp/Matrix.h index 5cdbd6e85..d8cd2942d 100644 --- a/cpp/Matrix.h +++ b/cpp/Matrix.h @@ -302,6 +302,14 @@ Matrix inverse_square_root(const Matrix& A); /** Use SVD to calculate the positive square root of a matrix */ Matrix square_root_positive(const Matrix& A); +/** Calculate the LL^t decomposition of a S.P.D matrix */ +Matrix cholesky(const Matrix& A); + +/** Return the inverse of a S.P.D. matrix. Inversion is done via Cholesky decomposition. */ +Matrix cholesky_inverse(const Matrix &A); + + + // macro for unit tests #define EQUALITY(expected,actual)\ { if (!assert_equal(expected,actual)) \ diff --git a/cpp/testMatrix.cpp b/cpp/testMatrix.cpp index 785c99584..041824c5f 100644 --- a/cpp/testMatrix.cpp +++ b/cpp/testMatrix.cpp @@ -705,6 +705,50 @@ TEST( matrix, inverse_square_root ) EQUALITY(expected,actual); EQUALITY(measurement_covariance,inverse(actual*actual)); + + // Randomly generated test. This test really requires inverse to + // be working well; if it's not, there's the possibility of a + // bug in inverse masking a bug in this routine since we + // use the same inverse routing inside inverse_square_root() + // as we use here to check it. + + Matrix M = Matrix_(5, 5, + 0.0785892, 0.0137923, -0.0142219, -0.0171880, 0.0028726, + 0.0137923, 0.0908911, 0.0020775, -0.0101952, 0.0175868, + -0.0142219, 0.0020775, 0.0973051, 0.0054906, 0.0047064, + -0.0171880, -0.0101952, 0.0054906, 0.0892453, -0.0059468, + 0.0028726, 0.0175868, 0.0047064, -0.0059468, 0.0816517); + + expected = Matrix_(5, 5, + 3.567126953241796, 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.000000000000000, + -0.590030436566913, 3.362022286742925, 0.000000000000000, 0.000000000000000, 0.000000000000000, + 0.618207860252376, -0.168166020746503, 3.253086082942785, 0.000000000000000, 0.000000000000000, + 0.683045380655496, 0.283773848115276, -0.099969232183396, 3.433537147891568, 0.000000000000000, + -0.006740136923185, -0.669325697387650, -0.169716689114923, 0.171493059476284, 3.583921085468937); + EQUALITY(expected, inverse_square_root(M)); + +} + +/* *********************************************************************** */ +// M was generated as the covariance of a set of random numbers. L that +// we are checking against was generated via chol(M)' on octave +TEST( matrix, cholesky ) +{ + Matrix M = Matrix_(5, 5, + 0.0874197, -0.0030860, 0.0116969, 0.0081463, 0.0048741, + -0.0030860, 0.0872727, 0.0183073, 0.0125325, -0.0037363, + 0.0116969, 0.0183073, 0.0966217, 0.0103894, -0.0021113, + 0.0081463, 0.0125325, 0.0103894, 0.0747324, 0.0036415, + 0.0048741, -0.0037363, -0.0021113, 0.0036415, 0.0909464); + + Matrix expected = Matrix_(5, 5, + 0.295668226226627, 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.000000000000000, + -0.010437374483502, 0.295235094820875, 0.000000000000000, 0.000000000000000, 0.000000000000000, + 0.039560896175007, 0.063407813693827, 0.301721866387571, 0.000000000000000, 0.000000000000000, + 0.027552165831157, 0.043423266737274, 0.021695600982708, 0.267613525371710, 0.000000000000000, + 0.016485031422565, -0.012072546984405, -0.006621889326331, 0.014405837566082, 0.300462176944247); + + EQUALITY(expected, cholesky(M)); } /* ************************************************************************* */