diff --git a/cython/gtsam/tests/test_logging_optimizer.py b/cython/gtsam/tests/test_logging_optimizer.py index 69665db65..2560a72a2 100644 --- a/cython/gtsam/tests/test_logging_optimizer.py +++ b/cython/gtsam/tests/test_logging_optimizer.py @@ -43,6 +43,11 @@ class TestOptimizeComet(GtsamTestCase): self.optimizer = gtsam.GaussNewtonOptimizer( graph, initial, self.params) + self.lmparams = gtsam.LevenbergMarquardtParams() + self.lmoptimizer = gtsam.LevenbergMarquardtOptimizer( + graph, initial, self.lmparams + ) + # setup output capture self.capturedOutput = StringIO() sys.stdout = self.capturedOutput @@ -65,6 +70,16 @@ class TestOptimizeComet(GtsamTestCase): actual = self.optimizer.values() self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6) + def test_lm_simple_printing(self): + """Make sure we are properly terminating LM""" + def hook(_, error): + print(error) + + gtsam_optimize(self.lmoptimizer, self.lmparams, hook) + + actual = self.lmoptimizer.values() + self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6) + @unittest.skip("Not a test we want run every time, as needs comet.ml account") def test_comet(self): """Test with a comet hook.""" diff --git a/cython/gtsam/utils/logging_optimizer.py b/cython/gtsam/utils/logging_optimizer.py index a48413212..27b9b3a3a 100644 --- a/cython/gtsam/utils/logging_optimizer.py +++ b/cython/gtsam/utils/logging_optimizer.py @@ -46,6 +46,7 @@ def gtsam_optimize(optimizer, def check_convergence(optimizer, current_error, new_error): return (optimizer.iterations() >= params.getMaxIterations()) or ( gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(), - current_error, new_error)) or (optimizer.lambda_() > params.getlambdaUpperBound()) + current_error, new_error)) or ( + type(optimizer).__name__ == "LevenbergMarquardtOptimizer" and optimizer.lambda_() > params.getlambdaUpperBound()) optimize(optimizer, check_convergence, hook) return optimizer.values()