Merge pull request #385 from borglab/fix/logging_lambda

Fix lambda check in logging optimizer
release/4.3a0
Fan Jiang 2020-07-12 20:56:47 -04:00 committed by GitHub
commit 038bf297f6
2 changed files with 18 additions and 1 deletion
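For context: gtsam_optimize wraps an optimizer and invokes a user hook once per iteration; before this fix its convergence check never consulted LM's damping factor, so a run whose lambda had grown past lambdaUpperBound kept iterating. A minimal sketch of driving the patched helper, assuming `graph` and `initial` are an existing NonlinearFactorGraph and Values and that the helper is importable from gtsam.utils.logging_optimizer as in the GTSAM source tree:

    import gtsam
    from gtsam.utils.logging_optimizer import gtsam_optimize

    params = gtsam.LevenbergMarquardtParams()
    optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)

    def hook(_, error):
        # gtsam_optimize calls the hook each iteration with (optimizer, error);
        # here we ignore the optimizer and just log the current total error
        print(error)

    # With this fix, the loop also terminates once LM's lambda exceeds
    # params.getlambdaUpperBound(), instead of looping past LM's give-up point.
    result = gtsam_optimize(optimizer, params, hook)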


@@ -43,6 +43,11 @@ class TestOptimizeComet(GtsamTestCase):
         self.optimizer = gtsam.GaussNewtonOptimizer(
             graph, initial, self.params)
+        self.lmparams = gtsam.LevenbergMarquardtParams()
+        self.lmoptimizer = gtsam.LevenbergMarquardtOptimizer(
+            graph, initial, self.lmparams
+        )
+
         # setup output capture
         self.capturedOutput = StringIO()
         sys.stdout = self.capturedOutput
@@ -65,6 +70,16 @@ class TestOptimizeComet(GtsamTestCase):
         actual = self.optimizer.values()
         self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
+
+    def test_lm_simple_printing(self):
+        """Make sure we are properly terminating LM"""
+        def hook(_, error):
+            print(error)
+
+        gtsam_optimize(self.lmoptimizer, self.lmparams, hook)
+
+        actual = self.lmoptimizer.values()
+        self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
+
     @unittest.skip("Not a test we want run every time, as needs comet.ml account")
     def test_comet(self):
         """Test with a comet hook."""


@@ -46,5 +46,7 @@ def gtsam_optimize(optimizer,
     def check_convergence(optimizer, current_error, new_error):
         return (optimizer.iterations() >= params.getMaxIterations()) or (
             gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(),
-                                   current_error, new_error))
+                                   current_error, new_error)) or (
+            isinstance(optimizer, gtsam.LevenbergMarquardtOptimizer) and optimizer.lambda_() > params.getlambdaUpperBound())
     optimize(optimizer, check_convergence, hook)
+    return optimizer.values()
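In words: the run now stops when the iteration cap is reached, when gtsam.checkConvergence reports error-based convergence, or (the new clause) when an LM optimizer's damping factor has grown past lambdaUpperBound, which is LM's own signal that it cannot make further progress. The isinstance guard matters because only LevenbergMarquardtOptimizer exposes lambda_(); without it the check would fail on, e.g., GaussNewtonOptimizer. The same predicate, unrolled for readability in a sketch that passes params explicitly rather than closing over it as the patched code does:

    def check_convergence(optimizer, params, current_error, new_error):
        # 1. hard cap on iterations
        if optimizer.iterations() >= params.getMaxIterations():
            return True
        # 2. relative / absolute / total error convergence
        if gtsam.checkConvergence(params.getRelativeErrorTol(),
                                  params.getAbsoluteErrorTol(),
                                  params.getErrorTol(),
                                  current_error, new_error):
            return True
        # 3. LM-specific: damping exceeded its upper bound (guarded by
        #    isinstance, since only LM optimizers expose lambda_())
        return (isinstance(optimizer, gtsam.LevenbergMarquardtOptimizer)
                and optimizer.lambda_() > params.getlambdaUpperBound())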