From 8e5f1447e3d57b4722f1089b1fc5152e966cfc00 Mon Sep 17 00:00:00 2001
From: Fan Jiang
Date: Sat, 11 Jul 2020 11:54:40 -0400
Subject: [PATCH] Add check to ensure we only call lambda_() on an LM
 optimizer

---
 cython/gtsam/tests/test_logging_optimizer.py | 15 +++++++++++++++
 cython/gtsam/utils/logging_optimizer.py      |  3 ++-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/cython/gtsam/tests/test_logging_optimizer.py b/cython/gtsam/tests/test_logging_optimizer.py
index 69665db65..2560a72a2 100644
--- a/cython/gtsam/tests/test_logging_optimizer.py
+++ b/cython/gtsam/tests/test_logging_optimizer.py
@@ -43,6 +43,11 @@ class TestOptimizeComet(GtsamTestCase):
         self.optimizer = gtsam.GaussNewtonOptimizer(
             graph, initial, self.params)
 
+        self.lmparams = gtsam.LevenbergMarquardtParams()
+        self.lmoptimizer = gtsam.LevenbergMarquardtOptimizer(
+            graph, initial, self.lmparams
+        )
+
         # setup output capture
         self.capturedOutput = StringIO()
         sys.stdout = self.capturedOutput
@@ -65,6 +70,16 @@ class TestOptimizeComet(GtsamTestCase):
         actual = self.optimizer.values()
         self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
 
+    def test_lm_simple_printing(self):
+        """Make sure we are properly terminating LM"""
+        def hook(_, error):
+            print(error)
+
+        gtsam_optimize(self.lmoptimizer, self.lmparams, hook)
+
+        actual = self.lmoptimizer.values()
+        self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
+
     @unittest.skip("Not a test we want run every time, as needs comet.ml account")
     def test_comet(self):
         """Test with a comet hook."""
diff --git a/cython/gtsam/utils/logging_optimizer.py b/cython/gtsam/utils/logging_optimizer.py
index a48413212..27b9b3a3a 100644
--- a/cython/gtsam/utils/logging_optimizer.py
+++ b/cython/gtsam/utils/logging_optimizer.py
@@ -46,6 +46,7 @@ def gtsam_optimize(optimizer,
     def check_convergence(optimizer, current_error, new_error):
         return (optimizer.iterations() >= params.getMaxIterations()) or (
             gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(),
-                                   current_error, new_error)) or (optimizer.lambda_() > params.getlambdaUpperBound())
+                                   current_error, new_error)) or (
+                type(optimizer).__name__ == "LevenbergMarquardtOptimizer" and optimizer.lambda_() > params.getlambdaUpperBound())
     optimize(optimizer, check_convergence, hook)
     return optimizer.values()
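
For reference, a minimal usage sketch of the guarded convergence check,
assuming the cython-era wrapper naming used by the test above
(gtsam.noiseModel_Unit, gtsam.PriorFactorRot3). The one-factor graph, KEY,
and the Rot3.Rz(0.1) initial value are illustrative assumptions, not part
of this patch:

    import gtsam
    from gtsam.utils.logging_optimizer import gtsam_optimize

    KEY = 0
    graph = gtsam.NonlinearFactorGraph()
    # A single prior on a rotation; any small problem works for illustration.
    graph.add(gtsam.PriorFactorRot3(KEY, gtsam.Rot3(),
                                    gtsam.noiseModel_Unit.Create(3)))
    initial = gtsam.Values()
    initial.insert(KEY, gtsam.Rot3.Rz(0.1))

    def hook(optimizer, error):
        print(error)

    # LM path: the type-name guard passes, so lambda_() is compared against
    # params.getlambdaUpperBound() as an extra termination condition.
    lm_params = gtsam.LevenbergMarquardtParams()
    lm = gtsam.LevenbergMarquardtOptimizer(graph, initial, lm_params)
    gtsam_optimize(lm, lm_params, hook)

    # Gauss-Newton path: the guard is False and `and` short-circuits, so
    # lambda_() is never called; GaussNewtonOptimizer has no such method,
    # and the old unconditional call would have raised here.
    gn_params = gtsam.GaussNewtonParams()
    gn = gtsam.GaussNewtonOptimizer(graph, initial, gn_params)
    gtsam_optimize(gn, gn_params, hook)

On the design choice: the patch compares the class name as a string rather
than using isinstance. An isinstance(optimizer, gtsam.LevenbergMarquardtOptimizer)
check would likely behave the same; the string comparison simply avoids
assuming how the wrapper exposes the optimizer type.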