diff --git a/gtsam/nonlinear/GncOptimizer.h b/gtsam/nonlinear/GncOptimizer.h index dd6c8d17d..d4b25409f 100644 --- a/gtsam/nonlinear/GncOptimizer.h +++ b/gtsam/nonlinear/GncOptimizer.h @@ -264,11 +264,12 @@ public: // set initial mu switch (params_.lossType) { case GncParameters::GM: + // surrogate cost is convex for large mu return 2 * rmax_sq / params_.barcSq; // initial mu case GncParameters::TLS: - // initialize mu to the value specified in Remark 5 in GNC paper - // degenerate case: residual is close to zero so mu approximately equals - // to -1 + // initialize mu to the value specified in Remark 5 in GNC paper. + // surrogate cost is convex for mu close to zero + // degenerate case: 2 * rmax_sq - params_.barcSq < 0 (handled in the main loop) return params_.barcSq / (2 * rmax_sq - params_.barcSq); default: throw std::runtime_error( @@ -280,9 +281,10 @@ public: double updateMu(const double mu) const { switch (params_.lossType) { case GncParameters::GM: - return std::max(1.0, mu / params_.muStep); // reduce mu, but saturate at 1 + // reduce mu, but saturate at 1 (original cost is recovered for mu -> 1) + return std::max(1.0, mu / params_.muStep); case GncParameters::TLS: - // increases mu at each iteration + // increases mu at each iteration (original cost is recovered for mu -> inf) return mu * params_.muStep; default: throw std::runtime_error( diff --git a/tests/testGncOptimizer.cpp b/tests/testGncOptimizer.cpp index b3bef11e0..85a196bd7 100644 --- a/tests/testGncOptimizer.cpp +++ b/tests/testGncOptimizer.cpp @@ -29,6 +29,7 @@ #include #include +#include #include #include @@ -285,6 +286,75 @@ TEST(GncOptimizer, calculateWeightsTLS) { CHECK(assert_equal(weights_expected, weights_actual, tol)); } +/* ************************************************************************* */ +TEST(GncOptimizer, calculateWeightsTLS2) { + + // create values + Point2 x_val(0.0, 0.0); + Point2 x_prior(1.0, 0.0); + Values initial; + initial.insert(X(1), x_val); + + // create very simple factor graph with a single factor 0.5 * 1/sigma^2 * || x - [1;0] ||^2 + double sigma = 1; + SharedDiagonal noise = + noiseModel::Diagonal::Sigmas(Vector2(sigma, sigma)); + NonlinearFactorGraph nfg; + nfg.add(PriorFactor(X(1),x_prior,noise)); + + // cost of the factor: + DOUBLES_EQUAL(0.5 * 1/(sigma*sigma), nfg.error(initial), tol); + + // check the TLS weights are correct: CASE 1: residual below barcsq + { + // expected: + Vector weights_expected = Vector::Zero(1); + weights_expected[0] = 1.0; // inlier + // actual: + GaussNewtonParams gnParams; + GncParams gncParams(gnParams); + gncParams.setLossType( + GncParams::RobustLossType::TLS); + gncParams.setInlierThreshold(0.51); // if inlier threshold is slightly larger than 0.5, then measurement is inlier + auto gnc = GncOptimizer>(nfg, initial, gncParams); + double mu = 1e6; + Vector weights_actual = gnc.calculateWeights(initial, mu); + CHECK(assert_equal(weights_expected, weights_actual, tol)); + } + // check the TLS weights are correct: CASE 2: residual above barcsq + { + // expected: + Vector weights_expected = Vector::Zero(1); + weights_expected[0] = 0.0; // outlier + // actual: + GaussNewtonParams gnParams; + GncParams gncParams(gnParams); + gncParams.setLossType( + GncParams::RobustLossType::TLS); + gncParams.setInlierThreshold(0.49); // if inlier threshold is slightly below 0.5, then measurement is outlier + auto gnc = GncOptimizer>(nfg, initial, gncParams); + double mu = 1e6; // very large mu recovers original TLS cost + Vector weights_actual = gnc.calculateWeights(initial, mu); + CHECK(assert_equal(weights_expected, weights_actual, tol)); + } + // check the TLS weights are correct: CASE 2: residual at barcsq + { + // expected: + Vector weights_expected = Vector::Zero(1); + weights_expected[0] = 0.5; // undecided + // actual: + GaussNewtonParams gnParams; + GncParams gncParams(gnParams); + gncParams.setLossType( + GncParams::RobustLossType::TLS); + gncParams.setInlierThreshold(0.5); // if inlier threshold is slightly below 0.5, then measurement is outlier + auto gnc = GncOptimizer>(nfg, initial, gncParams); + double mu = 1e6; // very large mu recovers original TLS cost + Vector weights_actual = gnc.calculateWeights(initial, mu); + CHECK(assert_equal(weights_expected, weights_actual, 1e-5)); + } +} + /* ************************************************************************* */ TEST(GncOptimizer, makeWeightedGraph) { // create original factor