Add check to ensure we are calling lambda on a LM

release/4.3a0
Fan Jiang 2020-07-11 11:54:40 -04:00
parent 686ea137e7
commit 8e5f1447e3
2 changed files with 17 additions and 1 deletions

View File

@ -43,6 +43,11 @@ class TestOptimizeComet(GtsamTestCase):
self.optimizer = gtsam.GaussNewtonOptimizer(
graph, initial, self.params)
self.lmparams = gtsam.LevenbergMarquardtParams()
self.lmoptimizer = gtsam.LevenbergMarquardtOptimizer(
graph, initial, self.lmparams
)
# setup output capture
self.capturedOutput = StringIO()
sys.stdout = self.capturedOutput
@ -65,6 +70,16 @@ class TestOptimizeComet(GtsamTestCase):
actual = self.optimizer.values()
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
def test_lm_simple_printing(self):
"""Make sure we are properly terminating LM"""
def hook(_, error):
print(error)
gtsam_optimize(self.lmoptimizer, self.lmparams, hook)
actual = self.lmoptimizer.values()
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
@unittest.skip("Not a test we want run every time, as needs comet.ml account")
def test_comet(self):
"""Test with a comet hook."""

View File

@ -46,6 +46,7 @@ def gtsam_optimize(optimizer,
def check_convergence(optimizer, current_error, new_error):
return (optimizer.iterations() >= params.getMaxIterations()) or (
gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(),
current_error, new_error)) or (optimizer.lambda_() > params.getlambdaUpperBound())
current_error, new_error)) or (
type(optimizer).__name__ == "LevenbergMarquardtOptimizer" and optimizer.lambda_() > params.getlambdaUpperBound())
optimize(optimizer, check_convergence, hook)
return optimizer.values()