diff --git a/cpp/GaussianConditional.cpp b/cpp/GaussianConditional.cpp index fa6fbb570..efad7d0d0 100644 --- a/cpp/GaussianConditional.cpp +++ b/cpp/GaussianConditional.cpp @@ -101,7 +101,7 @@ Vector GaussianConditional::solve(const VectorConfig& x) const { const Matrix& Aj = it->second; rhs -= Aj * x[j]; } - Vector result = backsubstitution(R_, rhs); + Vector result = backSubstituteUpper(R_, rhs, true); return result; } diff --git a/cpp/Matrix.cpp b/cpp/Matrix.cpp index 5d3bc47fa..98b90b895 100644 --- a/cpp/Matrix.cpp +++ b/cpp/Matrix.cpp @@ -405,33 +405,43 @@ void householder(Matrix &A, size_t k) { } /* ************************************************************************* */ -Vector backsubstitution(const Matrix& R, const Vector& b) -{ - // result vector - Vector result(b.size()); +Vector backSubstituteUpper(const Matrix& U, const Vector& b, bool unit) { + size_t m = U.size1(), n = U.size2(); +#ifndef NDEBUG + if (m!=n) + throw invalid_argument("backSubstituteUpper: U must be square"); +#endif - double tmp = 0.0; - double div = 0.0; + Vector result(n); + for (size_t i = n; i > 0; i--) { + double tmp = b(i-1); + for (size_t j = i+1; j <= n; j++) + tmp -= U(i-1,j-1) * result(j-1); + if (!unit) tmp /= U(i-1,i-1); + result(i-1) = tmp; + } - int m = R.size1(), n = R.size2(), cols=n; + return result; +} - // check each row for non zero values - for( size_t i = m ; i > 0 ; i--){ - cols--; - int j = n; +/* ************************************************************************* */ +Vector backSubstituteLower(const Matrix& L, const Vector& b, bool unit) { + size_t m = L.size1(), n = L.size2(); +#ifndef NDEBUG + if (m!=n) + throw invalid_argument("backSubstituteLower: L must be square"); +#endif - div = R(i-1, cols); + Vector result(n); + for (size_t i = 1; i <= n; i++) { + double tmp = b(i-1); + for (size_t j = 1; j < i; j++) + tmp -= L(i-1,j-1) * result(j-1); + if (!unit) tmp /= L(i-1,i-1); + result(i-1) = tmp; + } - for( int ii = i ; ii < n ; ii++ ){ - j = j - 1; - tmp = tmp + R(i-1,j) * result(j); - } - // assigne the result vector - result(i-1) = (b(i-1) - tmp) / div; - tmp = 0.0; - } - - return result; + return result; } /* ************************************************************************* */ diff --git a/cpp/Matrix.h b/cpp/Matrix.h index c2c95d4ab..072afb5c0 100644 --- a/cpp/Matrix.h +++ b/cpp/Matrix.h @@ -207,12 +207,24 @@ void householder_(Matrix& A, size_t k); void householder(Matrix& A, size_t k); /** - * backsubstitution - * @param R an upper triangular matrix + * backSubstitute U*x=b + * @param U an upper triangular matrix * @param b a RHS vector - * @return the solution of Rx=b + * @param unit, set tru if unit triangular + * @return the solution x of U*x=b + * TODO: use boost + */ +Vector backSubstituteUpper(const Matrix& U, const Vector& b, bool unit=false); + +/** + * backSubstitute L*x=b + * @param L an lower triangular matrix + * @param b a RHS vector + * @param unit, set tru if unit triangular + * @return the solution x of L*x=b + * TODO: use boost */ -Vector backsubstitution(const Matrix& R, const Vector& b); +Vector backSubstituteLower(const Matrix& L, const Vector& d, bool unit=false); /** * create a matrix by stacking other matrices diff --git a/cpp/testMatrix.cpp b/cpp/testMatrix.cpp index c18d1f0f6..283588129 100644 --- a/cpp/testMatrix.cpp +++ b/cpp/testMatrix.cpp @@ -386,38 +386,34 @@ TEST( matrix, inverse ) CHECK(assert_equal(expected, Ainv, 1e-4)); } -/* ************************************************************************* */ -/* unit test for backsubstitution */ /* ************************************************************************* */ TEST( matrix, backsubtitution ) { - // TEST ONE 2x2 matrix - Vector expectedA(2); - expectedA(0) = 3.6250 ; expectedA(1) = -0.75; + // TEST ONE 2x2 matrix U1*x=b1 + Vector expected1 = Vector_(2, 3.6250, -0.75); + Matrix U1 = Matrix_(2, 2, + 2., 3., + 0., 4.); + Vector b1 = U1*expected1; + CHECK( assert_equal(expected1 , backSubstituteUpper(U1, b1), 0.000001)); - // create a 2x2 matrix - double dataA[] = {2, 3, - 0, 4 }; - Matrix A = Matrix_(2,2,dataA); - Vector Ab(2); Ab(0) = 5; Ab(1) = -3; + // TEST TWO 3x3 matrix U2*x=b2 + Vector expected2 = Vector_(3, 5.5, -8.5, 5.); + Matrix U2 = Matrix_(3, 3, + 3., 5., 6., + 0., 2., 3., + 0., 0., 1.); + Vector b2 = U2*expected2; + CHECK( assert_equal(expected2 , backSubstituteUpper(U2, b2), 0.000001)); - CHECK( assert_equal(expectedA , backsubstitution(A, Ab), 0.000001)); - - // TEST TWO 3x3 matrix - Vector expectedB(3); - expectedB(0) = 5.5 ; expectedB(1) = -8.5; expectedB(2) = 5; - - - // create a 3x3 matrix - double dataB[] = { 3, 5, 6, - 0, 2, 3, - 0, 0, 1 }; - Matrix B = Matrix_(3,3,dataB); - - Vector Bb(3); - Bb(0) = 4; Bb(1) = -2; Bb(2) = 5; - - CHECK( assert_equal(expectedB , backsubstitution(B, Bb), 0.000001)); + // TEST THREE Lower triangular 3x3 matrix L3*x=b3 + Vector expected3 = Vector_(3, 1., 1., 1.); + Matrix L3 = Matrix_(3, 3, + 3., 0., 0., + 5., 2., 0., + 6., 3., 1.); + Vector b3 = L3*expected3; + CHECK( assert_equal(expected3 , backSubstituteLower(L3, b3), 0.000001)); } /* ************************************************************************* */