tls done except unit tests

release/4.3a0
jingnanshi 2020-12-07 13:24:49 -05:00
parent 47775a7a4f
commit 9903fb91d0
2 changed files with 30 additions and 5 deletions

View File

@ -65,6 +65,7 @@ public:
size_t maxIterations = 100; /* maximum number of iterations*/
double barcSq = 1.0; /* a factor is considered an inlier if factor.error() < barcSq. Note that factor.error() whitens by the covariance*/
double muStep = 1.4; /* multiplicative factor to reduce/increase the mu in gnc */
double relativeMuTol = 1e-5; ///< The maximum relative mu decrease to stop iterating
VerbosityGNC verbosityGNC = SILENT; /* verbosity level */
std::vector<size_t> knownInliers = std::vector<size_t>(); /* slots in the factor graph corresponding to measurements that we know are inliers */
@ -89,6 +90,10 @@ public:
void setMuStep(const double step) {
muStep = step;
}
/// Set the maximum relative difference in mu values to stop iterating
void setRelativeMuTol(double value) {
relativeMuTol = value;
}
/// Set the verbosity level
void setVerbosityGNC(const VerbosityGNC verbosity) {
verbosityGNC = verbosity;
@ -196,6 +201,7 @@ public:
GaussNewtonOptimizer baseOptimizer(nfg_, state_);
Values result = baseOptimizer.optimize();
double mu = initializeMu();
double mu_prev = mu;
// handle the degenerate case for TLS cost that corresponds to small
// maximum residual error at initialization
@ -225,7 +231,7 @@ public:
result = baseOptimizer_iter.optimize();
// stopping condition
if (checkMuConvergence(mu)) {
if (checkMuConvergence(mu, mu_prev)) {
// display info
if (params_.verbosityGNC >= GncParameters::VerbosityGNC::SUMMARY) {
std::cout << "final iterations: " << iter << std::endl;
@ -235,6 +241,7 @@ public:
break;
}
// otherwise update mu
mu_prev = mu;
mu = updateMu(mu);
}
return result;
@ -279,11 +286,12 @@ public:
}
/// check if we have reached the value of mu for which the surrogate loss matches the original loss
bool checkMuConvergence(const double mu) const {
bool checkMuConvergence(const double mu, const double mu_prev) const {
switch (params_.lossType) {
case GncParameters::GM:
return std::fabs(mu - 1.0) < 1e-9; // mu=1 recovers the original GM function
// TODO: Add TLS
case GncParameters::TLS:
return std::fabs(mu - mu_prev) < params_.relativeMuTol;
default:
throw std::runtime_error(
"GncOptimizer::checkMuConvergence: called with unknown loss type.");
@ -341,7 +349,22 @@ public:
}
}
return weights;
// TODO: Add TLS
case GncParameters::TLS: // use eq (14) in GNC paper
double upperbound = (mu + 1) / mu * params_.barcSq;
double lowerbound = mu / (mu +1 ) * params_.barcSq;
for (size_t k : unknownWeights) {
if (nfg_[k]) {
double u2_k = nfg_[k]->error(currentEstimate); // squared (and whitened) residual
if (u2_k >= upperbound ) {
weights[k] = 0;
} else if (u2_k <= lowerbound) {
weights[k] = 1;
} else {
weights[k] = std::sqrt(params_.barcSq * mu * (mu + 1) / u2_k ) - mu;
}
}
}
return weights;
default:
throw std::runtime_error(
"GncOptimizer::calculateWeights: called with unknown loss type.");

View File

@ -162,7 +162,9 @@ TEST(GncOptimizer, checkMuConvergence) {
gncParams);
double mu = 1.0;
CHECK(gnc.checkMuConvergence(mu));
CHECK(gnc.checkMuConvergence(mu, 0));
// TODO: test relative mu convergence
}
/* ************************************************************************* */