fix NaN problem
parent
448ada270a
commit
69760f7a4c
|
@ -10,10 +10,11 @@
|
|||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
#include "DenseQR.h"
|
||||
|
||||
//#define DEBUG_MEMORY
|
||||
#define DEBUG_MEMORY
|
||||
|
||||
// all the lapack functions we need here
|
||||
extern "C" {
|
||||
|
@ -33,6 +34,40 @@ namespace gtsam {
|
|||
static char storev = 'C';
|
||||
static char trans = 'T';
|
||||
|
||||
/* ************************************************************************* */
|
||||
// check NaN in the input matrix
|
||||
void CheckNaN(int m, int n, double *A, const char* msg) {
|
||||
bool hasNaN = false;
|
||||
for(int i=0; i<m; i++) {
|
||||
for(int j=0; j<n; j++)
|
||||
if (isnan(A[j*m+i]))
|
||||
throw std::invalid_argument(msg);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// remove NaN in the input matrix
|
||||
void RemoveNaN(int m, int n, double *A) {
|
||||
bool hasNaN = false;
|
||||
for(int i=0; i<m; i++) {
|
||||
for(int j=0; j<n; j++)
|
||||
if (isnan(A[j*m+i]))
|
||||
A[j*m+i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// check NaN in the input matrix
|
||||
bool HasNaN(int m, int n, double *A) {
|
||||
bool hasNaN = false;
|
||||
for(int i=0; i<m; i++) {
|
||||
for(int j=0; j<n; j++)
|
||||
if (isnan(A[j*m+i]))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* the wrapper for LAPACK dlarft_ and dlarfb_ function
|
||||
|
@ -59,8 +94,26 @@ namespace gtsam {
|
|||
|| sizeBlock <= 1;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// write input to the disk for debugging
|
||||
void WriteInput(int m, int n, int numPivotColumns, double *A, int *stairs) {
|
||||
ofstream fs;
|
||||
fs.open ("denseqr.txt", ios::out | ios::trunc);
|
||||
fs << m << " " << n << " " << numPivotColumns << endl;
|
||||
for(int i=0; i<m; i++) {
|
||||
for(int j=0; j<n; j++)
|
||||
fs << A[j*m+i] << "\t";
|
||||
fs << endl;
|
||||
}
|
||||
fs.close();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void DenseQR(int m, int n, int numPivotColumns, double *A, int *stairs, double *workspace) {
|
||||
|
||||
// WriteInput(m, n, numPivotColumns, A, stairs);
|
||||
// CheckNaN(m, n, A, "DenseQR: the input matrix has NaN");
|
||||
|
||||
if (A == NULL) throw std::invalid_argument("DenseQR: A == NULL!");
|
||||
if (stairs == NULL) throw std::invalid_argument("DenseQR: stairs == NULL!");
|
||||
if (workspace == NULL) throw std::invalid_argument("DenseQR: W == NULL!");
|
||||
|
@ -69,9 +122,9 @@ namespace gtsam {
|
|||
if (numPivotColumns < 0 || numPivotColumns > n)
|
||||
throw std::invalid_argument("DenseQR: numPivotColumns < 0l || numPivotColumns > n");
|
||||
|
||||
double tau[n]; // the scalar in Householder
|
||||
double tau, Tau[n]; // the scalar in Householder
|
||||
int row1stHH = 0, numGoodHHs = 0, numPendingHHs = 0;
|
||||
int colPendingHHEnd = 0;
|
||||
int colPendingHHStart = 0, colPendingHHEnd = 0;
|
||||
double *vectorHH = A;
|
||||
int numZeros = 0;
|
||||
int sizeBlock = m < 32 ? m : 32;
|
||||
|
@ -96,17 +149,19 @@ namespace gtsam {
|
|||
if (row1stHH >= m) throw std::runtime_error("DenseQR: row1stHH >= m");
|
||||
if (colPendingHHEnd >= n) throw std::runtime_error("DenseQR: colPendingHHEnd >= n");
|
||||
#endif
|
||||
dlarftb_wrap(stairStartLast - row1stHH, n - colPendingHHEnd, numPendingHHs, m, m,
|
||||
vectorHH, tau, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
|
||||
dlarftb_wrap(stairStartLast - row1stHH, n - colPendingHHEnd, numPendingHHs, m, m,
|
||||
vectorHH, Tau + colPendingHHStart, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
|
||||
}
|
||||
|
||||
// compute Householder for the current column
|
||||
int n_ = stairStart - numGoodHHs;
|
||||
double *X = &A[numGoodHHs+col*m];
|
||||
dlarfg_(&n_, X, X + 1, &one, tau);
|
||||
dlarfg_(&n_, X, X + 1, &one, &tau);
|
||||
Tau[col] = tau;
|
||||
if (!numPendingHHs) {
|
||||
row1stHH = numGoodHHs;
|
||||
vectorHH = &A[row1stHH+col*m];
|
||||
colPendingHHStart = col;
|
||||
#ifdef DEBUG_MEMORY
|
||||
if (row1stHH+col*m >= m*n) throw std::runtime_error("DenseQR: row1stHH+col*m >= m*n");
|
||||
#endif
|
||||
|
@ -120,7 +175,7 @@ namespace gtsam {
|
|||
if (m1 > 0 && n1 > 0) {
|
||||
double *A1 = &A[numGoodHHs+col*m], *C1 = A1 + m, v0 = *A1;
|
||||
*A1 = 1 ;
|
||||
dlarf_ (&left, &m1, &n1, A1, &one, tau, C1, &m, workspace) ;
|
||||
dlarf_ (&left, &m1, &n1, A1, &one, &tau, C1, &m, workspace) ;
|
||||
*A1 = v0;
|
||||
numGoodHHs++;
|
||||
}
|
||||
|
@ -130,10 +185,12 @@ namespace gtsam {
|
|||
if (row1stHH >= m) throw std::runtime_error("DenseQR: row1stHH >= m");
|
||||
if (colPendingHHEnd >= n) throw std::runtime_error("DenseQR: colPendingHHEnd >= n");
|
||||
#endif
|
||||
dlarftb_wrap(stairStart - row1stHH, n - colPendingHHEnd, numPendingHHs, m, m,
|
||||
vectorHH, tau, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
|
||||
dlarftb_wrap(stairStart - row1stHH, n - colPendingHHEnd, numPendingHHs, m, m,
|
||||
vectorHH, Tau + colPendingHHStart, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
|
||||
}
|
||||
}
|
||||
|
||||
// CheckNaN(m, n, A, "DenseQR: the output matrix has NaN");
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
Loading…
Reference in New Issue