Fix inverse_square_root, add cholesky decomposition options

release/4.3a0
justinca 2010-01-22 22:28:03 +00:00
parent e88ae4a944
commit 16c55975c1
3 changed files with 123 additions and 0 deletions

View File

@ -18,6 +18,10 @@
#include <boost/foreach.hpp>
#include <boost/numeric/ublas/lu.hpp>
#include <boost/numeric/ublas/io.hpp>
#include <boost/numeric/ublas/triangular.hpp>
#include <boost/numeric/ublas/symmetric.hpp>
#include <boost/numeric/ublas/matrix_proxy.hpp>
#include <boost/numeric/ublas/vector_proxy.hpp>
#include <boost/tuple/tuple.hpp>
#include "Matrix.h"
@ -714,6 +718,55 @@ double** createNRC(Matrix& A) {
return a;
}
/* ******************************************
*
* Modified from Justin's codebase
*
* Idea came from other public domain code. Takes a S.P.D. matrix
* and computes the LL^t decomposition. returns L, which is lower
* triangular. Note this is the opposite convention from Matlab,
* which calculates Q'Q where Q is upper triangular.
*
* ******************************************/
namespace BNU = boost::numeric::ublas;
Matrix cholesky(const Matrix& A)
{
assert(A.size1() == A.size2());
Matrix L = zeros(A.size1(), A.size1());
for (size_t i = 0 ; i < A.size1(); i++) {
double p = A(i,i) - BNU::inner_prod( BNU::project( BNU::row(L, i), BNU::range(0, i) ),
BNU::project( BNU::row(L, i), BNU::range(0, i) ) );
assert(p > 0); // Rank failure
double l_i_i = sqrt(p);
L(i,i) = l_i_i;
BNU::matrix_column<Matrix> l_i(L, i);
project( l_i, BNU::range(i+1, A.size1()) )
= ( BNU::project( BNU::column(A, i), BNU::range(i+1, A.size1()) )
- BNU::prod( BNU::project(L, BNU::range(i+1, A.size1()), BNU::range(0, i)),
BNU::project(BNU::row(L, i), BNU::range(0, i) ) ) ) / l_i_i;
}
return L;
}
/*
* This is not ultra efficient, but not terrible, either.
*/
Matrix cholesky_inverse(const Matrix &A)
{
Matrix L = cholesky(A);
Matrix inv(boost::numeric::ublas::identity_matrix<double>(A.size1()));
inplace_solve (L, inv, BNU::lower_tag ());
return BNU::prod(trans(inv), inv);
}
/* ************************************************************************* */
/** SVD */
/* ************************************************************************* */
@ -749,6 +802,7 @@ void svd(const Matrix& A, Matrix& U, Vector& s, Matrix& V) {
svd(U,s,V); // call in-place version
}
#if 0
/* ************************************************************************* */
// TODO, would be faster with Cholesky
Matrix inverse_square_root(const Matrix& A) {
@ -766,6 +820,23 @@ Matrix inverse_square_root(const Matrix& A) {
for(size_t i = 0; i<m; i++) S(i) = - pow(S(i),-0.5);
return vector_scale(S, V); // V*S;
}
#endif
/* ************************************************************************* */
// New, improved, with 100% more Cholesky goodness!
//
// Semantics:
// if B = inverse_square_root(A), then all of the following are true:
// inv(B) * inv(B)' == A
// inv(B' * B) == A
Matrix inverse_square_root(const Matrix& A) {
Matrix L = cholesky(A);
Matrix inv(boost::numeric::ublas::identity_matrix<double>(A.size1()));
inplace_solve (L, inv, BNU::lower_tag ());
return inv;
}
/* ************************************************************************* */
Matrix square_root_positive(const Matrix& A) {

View File

@ -302,6 +302,14 @@ Matrix inverse_square_root(const Matrix& A);
/** Use SVD to calculate the positive square root of a matrix */
Matrix square_root_positive(const Matrix& A);
/** Calculate the LL^t decomposition of a S.P.D matrix */
Matrix cholesky(const Matrix& A);
/** Return the inverse of a S.P.D. matrix. Inversion is done via Cholesky decomposition. */
Matrix cholesky_inverse(const Matrix &A);
// macro for unit tests
#define EQUALITY(expected,actual)\
{ if (!assert_equal(expected,actual)) \

View File

@ -705,6 +705,50 @@ TEST( matrix, inverse_square_root )
EQUALITY(expected,actual);
EQUALITY(measurement_covariance,inverse(actual*actual));
// Randomly generated test. This test really requires inverse to
// be working well; if it's not, there's the possibility of a
// bug in inverse masking a bug in this routine since we
// use the same inverse routing inside inverse_square_root()
// as we use here to check it.
Matrix M = Matrix_(5, 5,
0.0785892, 0.0137923, -0.0142219, -0.0171880, 0.0028726,
0.0137923, 0.0908911, 0.0020775, -0.0101952, 0.0175868,
-0.0142219, 0.0020775, 0.0973051, 0.0054906, 0.0047064,
-0.0171880, -0.0101952, 0.0054906, 0.0892453, -0.0059468,
0.0028726, 0.0175868, 0.0047064, -0.0059468, 0.0816517);
expected = Matrix_(5, 5,
3.567126953241796, 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.000000000000000,
-0.590030436566913, 3.362022286742925, 0.000000000000000, 0.000000000000000, 0.000000000000000,
0.618207860252376, -0.168166020746503, 3.253086082942785, 0.000000000000000, 0.000000000000000,
0.683045380655496, 0.283773848115276, -0.099969232183396, 3.433537147891568, 0.000000000000000,
-0.006740136923185, -0.669325697387650, -0.169716689114923, 0.171493059476284, 3.583921085468937);
EQUALITY(expected, inverse_square_root(M));
}
/* *********************************************************************** */
// M was generated as the covariance of a set of random numbers. L that
// we are checking against was generated via chol(M)' on octave
TEST( matrix, cholesky )
{
Matrix M = Matrix_(5, 5,
0.0874197, -0.0030860, 0.0116969, 0.0081463, 0.0048741,
-0.0030860, 0.0872727, 0.0183073, 0.0125325, -0.0037363,
0.0116969, 0.0183073, 0.0966217, 0.0103894, -0.0021113,
0.0081463, 0.0125325, 0.0103894, 0.0747324, 0.0036415,
0.0048741, -0.0037363, -0.0021113, 0.0036415, 0.0909464);
Matrix expected = Matrix_(5, 5,
0.295668226226627, 0.000000000000000, 0.000000000000000, 0.000000000000000, 0.000000000000000,
-0.010437374483502, 0.295235094820875, 0.000000000000000, 0.000000000000000, 0.000000000000000,
0.039560896175007, 0.063407813693827, 0.301721866387571, 0.000000000000000, 0.000000000000000,
0.027552165831157, 0.043423266737274, 0.021695600982708, 0.267613525371710, 0.000000000000000,
0.016485031422565, -0.012072546984405, -0.006621889326331, 0.014405837566082, 0.300462176944247);
EQUALITY(expected, cholesky(M));
}
/* ************************************************************************* */