Using struct specialization to select MKL Householder QR in Eigen
parent
513b5fd8d6
commit
9b8004d780
|
|
@ -251,56 +251,60 @@ void householder_qr_inplace_unblocked(MatrixQR& mat, HCoeffs& hCoeffs, typename
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \internal */
|
/** \internal */
|
||||||
template<typename MatrixQR, typename HCoeffs>
|
template<typename MatrixQR, typename HCoeffs, typename MatrixQRScalar = typename MatrixQR::Scalar>
|
||||||
void householder_qr_inplace_blocked(MatrixQR& mat, HCoeffs& hCoeffs,
|
struct householder_qr_inplace_blocked
|
||||||
typename MatrixQR::Index maxBlockSize=32,
|
|
||||||
typename MatrixQR::Scalar* tempData = 0)
|
|
||||||
{
|
{
|
||||||
typedef typename MatrixQR::Index Index;
|
// This is specialized for MKL-supported Scalar types in HouseholderQR_MKL.h
|
||||||
typedef typename MatrixQR::Scalar Scalar;
|
static void run(MatrixQR& mat, HCoeffs& hCoeffs,
|
||||||
typedef Block<MatrixQR,Dynamic,Dynamic> BlockType;
|
typename MatrixQR::Index maxBlockSize=32,
|
||||||
|
typename MatrixQR::Scalar* tempData = 0)
|
||||||
Index rows = mat.rows();
|
|
||||||
Index cols = mat.cols();
|
|
||||||
Index size = (std::min)(rows, cols);
|
|
||||||
|
|
||||||
typedef Matrix<Scalar,Dynamic,1,ColMajor,MatrixQR::MaxColsAtCompileTime,1> TempType;
|
|
||||||
TempType tempVector;
|
|
||||||
if(tempData==0)
|
|
||||||
{
|
{
|
||||||
tempVector.resize(cols);
|
typedef typename MatrixQR::Index Index;
|
||||||
tempData = tempVector.data();
|
typedef typename MatrixQR::Scalar Scalar;
|
||||||
}
|
typedef Block<MatrixQR, Dynamic, Dynamic> BlockType;
|
||||||
|
|
||||||
Index blockSize = (std::min)(maxBlockSize,size);
|
Index rows = mat.rows();
|
||||||
|
Index cols = mat.cols();
|
||||||
|
Index size = (std::min)(rows, cols);
|
||||||
|
|
||||||
Index k = 0;
|
typedef Matrix<Scalar, Dynamic, 1, ColMajor, MatrixQR::MaxColsAtCompileTime, 1> TempType;
|
||||||
for (k = 0; k < size; k += blockSize)
|
TempType tempVector;
|
||||||
{
|
if(tempData==0)
|
||||||
Index bs = (std::min)(size-k,blockSize); // actual size of the block
|
|
||||||
Index tcols = cols - k - bs; // trailing columns
|
|
||||||
Index brows = rows-k; // rows of the block
|
|
||||||
|
|
||||||
// partition the matrix:
|
|
||||||
// A00 | A01 | A02
|
|
||||||
// mat = A10 | A11 | A12
|
|
||||||
// A20 | A21 | A22
|
|
||||||
// and performs the QR decomposition of [A11^T A21^T]^T
|
|
||||||
// and updates [A12^T A22^T]^T using level 3 operations.
|
|
||||||
// Finally, the algorithm continues on A22
|
|
||||||
|
|
||||||
BlockType A11_21 = mat.block(k,k,brows,bs);
|
|
||||||
Block<HCoeffs,Dynamic,1> hCoeffsSegment = hCoeffs.segment(k,bs);
|
|
||||||
|
|
||||||
householder_qr_inplace_unblocked(A11_21, hCoeffsSegment, tempData);
|
|
||||||
|
|
||||||
if(tcols)
|
|
||||||
{
|
{
|
||||||
BlockType A21_22 = mat.block(k,k+bs,brows,tcols);
|
tempVector.resize(cols);
|
||||||
apply_block_householder_on_the_left(A21_22,A11_21,hCoeffsSegment.adjoint());
|
tempData = tempVector.data();
|
||||||
|
}
|
||||||
|
|
||||||
|
Index blockSize = (std::min)(maxBlockSize, size);
|
||||||
|
|
||||||
|
Index k = 0;
|
||||||
|
for(k = 0; k < size; k += blockSize)
|
||||||
|
{
|
||||||
|
Index bs = (std::min)(size-k, blockSize); // actual size of the block
|
||||||
|
Index tcols = cols - k - bs; // trailing columns
|
||||||
|
Index brows = rows-k; // rows of the block
|
||||||
|
|
||||||
|
// partition the matrix:
|
||||||
|
// A00 | A01 | A02
|
||||||
|
// mat = A10 | A11 | A12
|
||||||
|
// A20 | A21 | A22
|
||||||
|
// and performs the QR decomposition of [A11^T A21^T]^T
|
||||||
|
// and updates [A12^T A22^T]^T using level 3 operations.
|
||||||
|
// Finally, the algorithm continues on A22
|
||||||
|
|
||||||
|
BlockType A11_21 = mat.block(k, k, brows, bs);
|
||||||
|
Block<HCoeffs, Dynamic, 1> hCoeffsSegment = hCoeffs.segment(k, bs);
|
||||||
|
|
||||||
|
householder_qr_inplace_unblocked(A11_21, hCoeffsSegment, tempData);
|
||||||
|
|
||||||
|
if(tcols)
|
||||||
|
{
|
||||||
|
BlockType A21_22 = mat.block(k, k+bs, brows, tcols);
|
||||||
|
apply_block_householder_on_the_left(A21_22, A11_21, hCoeffsSegment.adjoint());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
template<typename _MatrixType, typename Rhs>
|
template<typename _MatrixType, typename Rhs>
|
||||||
struct solve_retval<HouseholderQR<_MatrixType>, Rhs>
|
struct solve_retval<HouseholderQR<_MatrixType>, Rhs>
|
||||||
|
|
@ -352,7 +356,7 @@ HouseholderQR<MatrixType>& HouseholderQR<MatrixType>::compute(const MatrixType&
|
||||||
|
|
||||||
m_temp.resize(cols);
|
m_temp.resize(cols);
|
||||||
|
|
||||||
internal::householder_qr_inplace_blocked(m_qr, m_hCoeffs, 48, m_temp.data());
|
internal::householder_qr_inplace_blocked<MatrixType, HCoeffsType>::run(m_qr, m_hCoeffs, 48, m_temp.data());
|
||||||
|
|
||||||
m_isInitialized = true;
|
m_isInitialized = true;
|
||||||
return *this;
|
return *this;
|
||||||
|
|
|
||||||
|
|
@ -38,24 +38,26 @@
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
/** \internal Specialization for the data types supported by MKL */
|
/** \internal Specialization for the data types supported by MKL */
|
||||||
|
|
||||||
#define EIGEN_MKL_QR_NOPIV(EIGTYPE, MKLTYPE, MKLPREFIX) \
|
#define EIGEN_MKL_QR_NOPIV(EIGTYPE, MKLTYPE, MKLPREFIX) \
|
||||||
template<typename MatrixQR, typename HCoeffs> \
|
template<typename MatrixQR, typename HCoeffs> \
|
||||||
void householder_qr_inplace_blocked(MatrixQR& mat, HCoeffs& hCoeffs, \
|
struct householder_qr_inplace_blocked<MatrixQR, HCoeffs, EIGTYPE> \
|
||||||
typename MatrixQR::Index maxBlockSize=32, \
|
|
||||||
EIGTYPE* tempData = 0) \
|
|
||||||
{ \
|
{ \
|
||||||
lapack_int m = (lapack_int) mat.rows(); \
|
static void run(MatrixQR& mat, HCoeffs& hCoeffs, \
|
||||||
lapack_int n = (lapack_int) mat.cols(); \
|
typename MatrixQR::Index = 32, \
|
||||||
lapack_int lda = (lapack_int) mat.outerStride(); \
|
typename MatrixQR::Scalar* = 0) \
|
||||||
lapack_int matrix_order = (MatrixQR::IsRowMajor) ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR; \
|
{ \
|
||||||
LAPACKE_##MKLPREFIX##geqrf( matrix_order, m, n, (MKLTYPE*)mat.data(), lda, (MKLTYPE*)hCoeffs.data()); \
|
lapack_int m = (lapack_int) mat.rows(); \
|
||||||
hCoeffs.adjointInPlace(); \
|
lapack_int n = (lapack_int) mat.cols(); \
|
||||||
\
|
lapack_int lda = (lapack_int) mat.outerStride(); \
|
||||||
}
|
lapack_int matrix_order = (MatrixQR::IsRowMajor) ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR; \
|
||||||
|
LAPACKE_##MKLPREFIX##geqrf( matrix_order, m, n, (MKLTYPE*)mat.data(), lda, (MKLTYPE*)hCoeffs.data()); \
|
||||||
|
hCoeffs.adjointInPlace(); \
|
||||||
|
} \
|
||||||
|
};
|
||||||
|
|
||||||
EIGEN_MKL_QR_NOPIV(double, double, d)
|
EIGEN_MKL_QR_NOPIV(double, double, d)
|
||||||
EIGEN_MKL_QR_NOPIV(float, float, s)
|
EIGEN_MKL_QR_NOPIV(float, float, s)
|
||||||
|
|
|
||||||
|
|
@ -336,7 +336,7 @@ void inplace_QR(MATRIX& A) {
|
||||||
HCoeffsType hCoeffs(size);
|
HCoeffsType hCoeffs(size);
|
||||||
RowVectorType temp(cols);
|
RowVectorType temp(cols);
|
||||||
|
|
||||||
Eigen::internal::householder_qr_inplace_blocked(A, hCoeffs, 48, temp.data());
|
Eigen::internal::householder_qr_inplace_blocked<MATRIX, HCoeffsType>::run(A, hCoeffs, 48, temp.data());
|
||||||
|
|
||||||
zeroBelowDiagonal(A);
|
zeroBelowDiagonal(A);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue