diff --git a/gtsam/nonlinear/GncOptimizer.h b/gtsam/nonlinear/GncOptimizer.h index be7d046c6..b6e6933ec 100644 --- a/gtsam/nonlinear/GncOptimizer.h +++ b/gtsam/nonlinear/GncOptimizer.h @@ -65,6 +65,7 @@ public: size_t maxIterations = 100; /* maximum number of iterations*/ double barcSq = 1.0; /* a factor is considered an inlier if factor.error() < barcSq. Note that factor.error() whitens by the covariance*/ double muStep = 1.4; /* multiplicative factor to reduce/increase the mu in gnc */ + double relativeMuTol = 1e-5; ///< The maximum relative mu decrease to stop iterating VerbosityGNC verbosityGNC = SILENT; /* verbosity level */ std::vector knownInliers = std::vector(); /* slots in the factor graph corresponding to measurements that we know are inliers */ @@ -89,6 +90,10 @@ public: void setMuStep(const double step) { muStep = step; } + /// Set the maximum relative difference in mu values to stop iterating + void setRelativeMuTol(double value) { + relativeMuTol = value; + } /// Set the verbosity level void setVerbosityGNC(const VerbosityGNC verbosity) { verbosityGNC = verbosity; @@ -196,6 +201,7 @@ public: GaussNewtonOptimizer baseOptimizer(nfg_, state_); Values result = baseOptimizer.optimize(); double mu = initializeMu(); + double mu_prev = mu; // handle the degenerate case for TLS cost that corresponds to small // maximum residual error at initialization @@ -225,7 +231,7 @@ public: result = baseOptimizer_iter.optimize(); // stopping condition - if (checkMuConvergence(mu)) { + if (checkMuConvergence(mu, mu_prev)) { // display info if (params_.verbosityGNC >= GncParameters::VerbosityGNC::SUMMARY) { std::cout << "final iterations: " << iter << std::endl; @@ -235,6 +241,7 @@ public: break; } // otherwise update mu + mu_prev = mu; mu = updateMu(mu); } return result; @@ -279,11 +286,12 @@ public: } /// check if we have reached the value of mu for which the surrogate loss matches the original loss - bool checkMuConvergence(const double mu) const { + bool checkMuConvergence(const double mu, const double mu_prev) const { switch (params_.lossType) { case GncParameters::GM: return std::fabs(mu - 1.0) < 1e-9; // mu=1 recovers the original GM function - // TODO: Add TLS + case GncParameters::TLS: + return std::fabs(mu - mu_prev) < params_.relativeMuTol; default: throw std::runtime_error( "GncOptimizer::checkMuConvergence: called with unknown loss type."); @@ -341,7 +349,22 @@ public: } } return weights; - // TODO: Add TLS + case GncParameters::TLS: // use eq (14) in GNC paper + double upperbound = (mu + 1) / mu * params_.barcSq; + double lowerbound = mu / (mu +1 ) * params_.barcSq; + for (size_t k : unknownWeights) { + if (nfg_[k]) { + double u2_k = nfg_[k]->error(currentEstimate); // squared (and whitened) residual + if (u2_k >= upperbound ) { + weights[k] = 0; + } else if (u2_k <= lowerbound) { + weights[k] = 1; + } else { + weights[k] = std::sqrt(params_.barcSq * mu * (mu + 1) / u2_k ) - mu; + } + } + } + return weights; default: throw std::runtime_error( "GncOptimizer::calculateWeights: called with unknown loss type."); diff --git a/tests/testGncOptimizer.cpp b/tests/testGncOptimizer.cpp index 5006aa941..3f784b96e 100644 --- a/tests/testGncOptimizer.cpp +++ b/tests/testGncOptimizer.cpp @@ -162,7 +162,9 @@ TEST(GncOptimizer, checkMuConvergence) { gncParams); double mu = 1.0; - CHECK(gnc.checkMuConvergence(mu)); + CHECK(gnc.checkMuConvergence(mu, 0)); + + // TODO: test relative mu convergence } /* ************************************************************************* */