fix NaN problem

release/4.3a0
Kai Ni 2010-10-31 08:00:30 +00:00
parent 448ada270a
commit 69760f7a4c
1 changed files with 66 additions and 9 deletions

View File

@ -10,10 +10,11 @@
#include <algorithm>
#include <stdexcept>
#include <iostream>
#include <fstream>
#include "DenseQR.h"
//#define DEBUG_MEMORY
#define DEBUG_MEMORY
// all the lapack functions we need here
extern "C" {
@ -33,6 +34,40 @@ namespace gtsam {
static char storev = 'C';
static char trans = 'T';
/* ************************************************************************* */
// check NaN in the input matrix
void CheckNaN(int m, int n, double *A, const char* msg) {
bool hasNaN = false;
for(int i=0; i<m; i++) {
for(int j=0; j<n; j++)
if (isnan(A[j*m+i]))
throw std::invalid_argument(msg);
}
}
/* ************************************************************************* */
// remove NaN in the input matrix
void RemoveNaN(int m, int n, double *A) {
bool hasNaN = false;
for(int i=0; i<m; i++) {
for(int j=0; j<n; j++)
if (isnan(A[j*m+i]))
A[j*m+i] = 0;
}
}
/* ************************************************************************* */
// check NaN in the input matrix
bool HasNaN(int m, int n, double *A) {
bool hasNaN = false;
for(int i=0; i<m; i++) {
for(int j=0; j<n; j++)
if (isnan(A[j*m+i]))
return true;
}
return false;
}
/* ************************************************************************* */
/**
* the wrapper for LAPACK dlarft_ and dlarfb_ function
@ -59,8 +94,26 @@ namespace gtsam {
|| sizeBlock <= 1;
}
/* ************************************************************************* */
// write input to the disk for debugging
void WriteInput(int m, int n, int numPivotColumns, double *A, int *stairs) {
ofstream fs;
fs.open ("denseqr.txt", ios::out | ios::trunc);
fs << m << " " << n << " " << numPivotColumns << endl;
for(int i=0; i<m; i++) {
for(int j=0; j<n; j++)
fs << A[j*m+i] << "\t";
fs << endl;
}
fs.close();
}
/* ************************************************************************* */
void DenseQR(int m, int n, int numPivotColumns, double *A, int *stairs, double *workspace) {
// WriteInput(m, n, numPivotColumns, A, stairs);
// CheckNaN(m, n, A, "DenseQR: the input matrix has NaN");
if (A == NULL) throw std::invalid_argument("DenseQR: A == NULL!");
if (stairs == NULL) throw std::invalid_argument("DenseQR: stairs == NULL!");
if (workspace == NULL) throw std::invalid_argument("DenseQR: W == NULL!");
@ -69,9 +122,9 @@ namespace gtsam {
if (numPivotColumns < 0 || numPivotColumns > n)
throw std::invalid_argument("DenseQR: numPivotColumns < 0l || numPivotColumns > n");
double tau[n]; // the scalar in Householder
double tau, Tau[n]; // the scalar in Householder
int row1stHH = 0, numGoodHHs = 0, numPendingHHs = 0;
int colPendingHHEnd = 0;
int colPendingHHStart = 0, colPendingHHEnd = 0;
double *vectorHH = A;
int numZeros = 0;
int sizeBlock = m < 32 ? m : 32;
@ -97,16 +150,18 @@ namespace gtsam {
if (colPendingHHEnd >= n) throw std::runtime_error("DenseQR: colPendingHHEnd >= n");
#endif
dlarftb_wrap(stairStartLast - row1stHH, n - colPendingHHEnd, numPendingHHs, m, m,
vectorHH, tau, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
vectorHH, Tau + colPendingHHStart, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
}
// compute Householder for the current column
int n_ = stairStart - numGoodHHs;
double *X = &A[numGoodHHs+col*m];
dlarfg_(&n_, X, X + 1, &one, tau);
dlarfg_(&n_, X, X + 1, &one, &tau);
Tau[col] = tau;
if (!numPendingHHs) {
row1stHH = numGoodHHs;
vectorHH = &A[row1stHH+col*m];
colPendingHHStart = col;
#ifdef DEBUG_MEMORY
if (row1stHH+col*m >= m*n) throw std::runtime_error("DenseQR: row1stHH+col*m >= m*n");
#endif
@ -120,7 +175,7 @@ namespace gtsam {
if (m1 > 0 && n1 > 0) {
double *A1 = &A[numGoodHHs+col*m], *C1 = A1 + m, v0 = *A1;
*A1 = 1 ;
dlarf_ (&left, &m1, &n1, A1, &one, tau, C1, &m, workspace) ;
dlarf_ (&left, &m1, &n1, A1, &one, &tau, C1, &m, workspace) ;
*A1 = v0;
numGoodHHs++;
}
@ -131,9 +186,11 @@ namespace gtsam {
if (colPendingHHEnd >= n) throw std::runtime_error("DenseQR: colPendingHHEnd >= n");
#endif
dlarftb_wrap(stairStart - row1stHH, n - colPendingHHEnd, numPendingHHs, m, m,
vectorHH, tau, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
}
vectorHH, Tau + colPendingHHStart, &A[row1stHH+colPendingHHEnd*m], workspace, &numPendingHHs, &numZeros);
}
}
// CheckNaN(m, n, A, "DenseQR: the output matrix has NaN");
}
} // namespace gtsam