fix NaN problem
parent
448ada270a
commit
69760f7a4c
|
@ -10,10 +10,11 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "DenseQR.h"
|
#include "DenseQR.h"
|
||||||
|
|
||||||
//#define DEBUG_MEMORY
|
#define DEBUG_MEMORY
|
||||||
|
|
||||||
// all the lapack functions we need here
|
// all the lapack functions we need here
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
@ -33,6 +34,40 @@ namespace gtsam {
|
||||||
static char storev = 'C';
|
static char storev = 'C';
|
||||||
static char trans = 'T';
|
static char trans = 'T';
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// check NaN in the input matrix
|
||||||
|
void CheckNaN(int m, int n, double *A, const char* msg) {
|
||||||
|
bool hasNaN = false;
|
||||||
|
for(int i=0; i<m; i++) {
|
||||||
|
for(int j=0; j<n; j++)
|
||||||
|
if (isnan(A[j*m+i]))
|
||||||
|
throw std::invalid_argument(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// remove NaN in the input matrix
|
||||||
|
void RemoveNaN(int m, int n, double *A) {
|
||||||
|
bool hasNaN = false;
|
||||||
|
for(int i=0; i<m; i++) {
|
||||||
|
for(int j=0; j<n; j++)
|
||||||
|
if (isnan(A[j*m+i]))
|
||||||
|
A[j*m+i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// check NaN in the input matrix
|
||||||
|
bool HasNaN(int m, int n, double *A) {
|
||||||
|
bool hasNaN = false;
|
||||||
|
for(int i=0; i<m; i++) {
|
||||||
|
for(int j=0; j<n; j++)
|
||||||
|
if (isnan(A[j*m+i]))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/**
|
/**
|
||||||
* the wrapper for LAPACK dlarft_ and dlarfb_ function
|
* the wrapper for LAPACK dlarft_ and dlarfb_ function
|
||||||
|
@ -59,8 +94,26 @@ namespace gtsam {
|
||||||
|| sizeBlock <= 1;
|
|| sizeBlock <= 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// write input to the disk for debugging
|
||||||
|
void WriteInput(int m, int n, int numPivotColumns, double *A, int *stairs) {
|
||||||
|
ofstream fs;
|
||||||
|
fs.open ("denseqr.txt", ios::out | ios::trunc);
|
||||||
|
fs << m << " " << n << " " << numPivotColumns << endl;
|
||||||
|
for(int i=0; i<m; i++) {
|
||||||
|
for(int j=0; j<n; j++)
|
||||||
|
fs << A[j*m+i] << "\t";
|
||||||
|
fs << endl;
|
||||||
|
}
|
||||||
|
fs.close();
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void DenseQR(int m, int n, int numPivotColumns, double *A, int *stairs, double *workspace) {
|
void DenseQR(int m, int n, int numPivotColumns, double *A, int *stairs, double *workspace) {
|
||||||
|
|
||||||
|
// WriteInput(m, n, numPivotColumns, A, stairs);
|
||||||
|
// CheckNaN(m, n, A, "DenseQR: the input matrix has NaN");
|
||||||
|
|
||||||
if (A == NULL) throw std::invalid_argument("DenseQR: A == NULL!");
|
if (A == NULL) throw std::invalid_argument("DenseQR: A == NULL!");
|
||||||
if (stairs == NULL) throw std::invalid_argument("DenseQR: stairs == NULL!");
|
if (stairs == NULL) throw std::invalid_argument("DenseQR: stairs == NULL!");
|
||||||
if (workspace == NULL) throw std::invalid_argument("DenseQR: W == NULL!");
|
if (workspace == NULL) throw std::invalid_argument("DenseQR: W == NULL!");
|
||||||
|
@ -69,9 +122,9 @@ namespace gtsam {
|
||||||
if (numPivotColumns < 0 || numPivotColumns > n)
|
if (numPivotColumns < 0 || numPivotColumns > n)
|
||||||
throw std::invalid_argument("DenseQR: numPivotColumns < 0l || numPivotColumns > n");
|
throw std::invalid_argument("DenseQR: numPivotColumns < 0l || numPivotColumns > n");
|
||||||
|
|
||||||
double tau[n]; // the scalar in Householder
|
double tau, Tau[n]; // the scalar in Householder
|
||||||
int row1stHH = 0, numGoodHHs = 0, numPendingHHs = 0;
|
int row1stHH = 0, numGoodHHs = 0, numPendingHHs = 0;
|
||||||
int colPendingHHEnd = 0;
|
int colPendingHHStart = 0, colPendingHHEnd = 0;
|
||||||
double *vectorHH = A;
|
double *vectorHH = A;
|
||||||
int numZeros = 0;
|
int numZeros = 0;
|
||||||
int sizeBlock = m < 32 ? m : 32;
|
int sizeBlock = m < 32 ? m : 32;
|
||||||
|
@ -97,16 +150,18 @@ namespace gtsam {
|
||||||
if (colPendingHHEnd >= n) throw std::runtime_error("DenseQR: colPendingHHEnd >= n");
|
if (colPendingHHEnd >= n) throw std::runtime_error("DenseQR: colPendingHHEnd >= n");
|
||||||
#endif
|
#endif
|
||||||
dlarftb_wrap(stairStartLast - row1stHH, n - colPendingHHEnd, numPendingHHs, m, m,
|
dlarftb_wrap(stairStartLast - row1stHH, n - colPendingHHEnd, numPendingHHs, m, m,
|
||||||
vectorHH, tau, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
|
vectorHH, Tau + colPendingHHStart, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute Householder for the current column
|
// compute Householder for the current column
|
||||||
int n_ = stairStart - numGoodHHs;
|
int n_ = stairStart - numGoodHHs;
|
||||||
double *X = &A[numGoodHHs+col*m];
|
double *X = &A[numGoodHHs+col*m];
|
||||||
dlarfg_(&n_, X, X + 1, &one, tau);
|
dlarfg_(&n_, X, X + 1, &one, &tau);
|
||||||
|
Tau[col] = tau;
|
||||||
if (!numPendingHHs) {
|
if (!numPendingHHs) {
|
||||||
row1stHH = numGoodHHs;
|
row1stHH = numGoodHHs;
|
||||||
vectorHH = &A[row1stHH+col*m];
|
vectorHH = &A[row1stHH+col*m];
|
||||||
|
colPendingHHStart = col;
|
||||||
#ifdef DEBUG_MEMORY
|
#ifdef DEBUG_MEMORY
|
||||||
if (row1stHH+col*m >= m*n) throw std::runtime_error("DenseQR: row1stHH+col*m >= m*n");
|
if (row1stHH+col*m >= m*n) throw std::runtime_error("DenseQR: row1stHH+col*m >= m*n");
|
||||||
#endif
|
#endif
|
||||||
|
@ -120,7 +175,7 @@ namespace gtsam {
|
||||||
if (m1 > 0 && n1 > 0) {
|
if (m1 > 0 && n1 > 0) {
|
||||||
double *A1 = &A[numGoodHHs+col*m], *C1 = A1 + m, v0 = *A1;
|
double *A1 = &A[numGoodHHs+col*m], *C1 = A1 + m, v0 = *A1;
|
||||||
*A1 = 1 ;
|
*A1 = 1 ;
|
||||||
dlarf_ (&left, &m1, &n1, A1, &one, tau, C1, &m, workspace) ;
|
dlarf_ (&left, &m1, &n1, A1, &one, &tau, C1, &m, workspace) ;
|
||||||
*A1 = v0;
|
*A1 = v0;
|
||||||
numGoodHHs++;
|
numGoodHHs++;
|
||||||
}
|
}
|
||||||
|
@ -131,9 +186,11 @@ namespace gtsam {
|
||||||
if (colPendingHHEnd >= n) throw std::runtime_error("DenseQR: colPendingHHEnd >= n");
|
if (colPendingHHEnd >= n) throw std::runtime_error("DenseQR: colPendingHHEnd >= n");
|
||||||
#endif
|
#endif
|
||||||
dlarftb_wrap(stairStart - row1stHH, n - colPendingHHEnd, numPendingHHs, m, m,
|
dlarftb_wrap(stairStart - row1stHH, n - colPendingHHEnd, numPendingHHs, m, m,
|
||||||
vectorHH, tau, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
|
vectorHH, Tau + colPendingHHStart, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckNaN(m, n, A, "DenseQR: the output matrix has NaN");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
Loading…
Reference in New Issue