Switched to in-place update of the diagonal Hessian

release/4.3a0
Fan Jiang 2020-06-02 12:44:57 -04:00
parent 65da699e57
commit f73429133a
6 changed files with 32 additions and 4 deletions

View File

@ -102,6 +102,9 @@ namespace gtsam {
/// Return the diagonal of the Hessian for this factor /// Return the diagonal of the Hessian for this factor
virtual VectorValues hessianDiagonal() const = 0; virtual VectorValues hessianDiagonal() const = 0;
/// Add the current diagonal to a VectorValues instance
virtual void hessianDiagonalAdd(VectorValues& d) const = 0;
/// Raw memory access version of hessianDiagonal /// Raw memory access version of hessianDiagonal
virtual void hessianDiagonal(double* d) const = 0; virtual void hessianDiagonal(double* d) const = 0;

View File

@ -255,8 +255,7 @@ namespace gtsam {
VectorValues d; VectorValues d;
for (const sharedFactor& factor : *this) { for (const sharedFactor& factor : *this) {
if(factor){ if(factor){
VectorValues di = factor->hessianDiagonal(); factor->hessianDiagonalAdd(d);
d.addInPlace_(di);
} }
} }
return d; return d;

View File

@ -310,6 +310,17 @@ VectorValues HessianFactor::hessianDiagonal() const {
return d; return d;
} }
/* ************************************************************************* */
void HessianFactor::hessianDiagonalAdd(VectorValues &d) const {
for (DenseIndex j = 0; j < (DenseIndex)size(); ++j) {
if(d.exists(keys_[j])) {
d.at(keys_[j]) += info_.diagonal(j);
} else {
d.emplace(keys_[j], info_.diagonal(j));
}
}
}
/* ************************************************************************* */ /* ************************************************************************* */
// Raw memory access version should be called in Regular Factors only currently // Raw memory access version should be called in Regular Factors only currently
void HessianFactor::hessianDiagonal(double* d) const { void HessianFactor::hessianDiagonal(double* d) const {

View File

@ -296,6 +296,9 @@ namespace gtsam {
/// Return the diagonal of the Hessian for this factor /// Return the diagonal of the Hessian for this factor
VectorValues hessianDiagonal() const override; VectorValues hessianDiagonal() const override;
/// Add the current diagonal to a VectorValues instance
void hessianDiagonalAdd(VectorValues& d) const override;
/// Raw memory access version of hessianDiagonal /// Raw memory access version of hessianDiagonal
void hessianDiagonal(double* d) const override; void hessianDiagonal(double* d) const override;

View File

@ -544,6 +544,12 @@ Matrix JacobianFactor::information() const {
/* ************************************************************************* */ /* ************************************************************************* */
VectorValues JacobianFactor::hessianDiagonal() const { VectorValues JacobianFactor::hessianDiagonal() const {
VectorValues d; VectorValues d;
hessianDiagonalAdd(d);
return d;
}
/* ************************************************************************* */
void JacobianFactor::hessianDiagonalAdd(VectorValues& d) const {
for (size_t pos = 0; pos < size(); ++pos) { for (size_t pos = 0; pos < size(); ++pos) {
Key j = keys_[pos]; Key j = keys_[pos];
size_t nj = Ab_(pos).cols(); size_t nj = Ab_(pos).cols();
@ -554,9 +560,12 @@ VectorValues JacobianFactor::hessianDiagonal() const {
model_->whitenInPlace(column_k); model_->whitenInPlace(column_k);
dj(k) = dot(column_k, column_k); dj(k) = dot(column_k, column_k);
} }
d.emplace(j, dj); if(d.exists(j)) {
d.at(j) += dj;
} else {
d.emplace(j, dj);
}
} }
return d;
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -218,6 +218,9 @@ namespace gtsam {
/// Return the diagonal of the Hessian for this factor /// Return the diagonal of the Hessian for this factor
VectorValues hessianDiagonal() const override; VectorValues hessianDiagonal() const override;
/// Add the current diagonal to a VectorValues instance
void hessianDiagonalAdd(VectorValues& d) const override;
/// Raw memory access version of hessianDiagonal /// Raw memory access version of hessianDiagonal
void hessianDiagonal(double* d) const override; void hessianDiagonal(double* d) const override;