Create convenience wrapper function in logging_optimizer

release/4.3a0
Gerry Chen 2022-04-20 16:21:59 -04:00
parent 61eef0639a
commit 5796fe3488
2 changed files with 73 additions and 26 deletions

View File

@ -18,7 +18,7 @@ import numpy as np
from gtsam import Rot3 from gtsam import Rot3
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
from gtsam.utils.logging_optimizer import gtsam_optimize from gtsam.utils.logging_optimizer import gtsam_optimize, optimize_using
KEY = 0 KEY = 0
MODEL = gtsam.noiseModel.Unit.Create(3) MODEL = gtsam.noiseModel.Unit.Create(3)
@ -34,19 +34,18 @@ class TestOptimizeComet(GtsamTestCase):
rotations = {R, R.inverse()} # mean is the identity rotations = {R, R.inverse()} # mean is the identity
self.expected = Rot3() self.expected = Rot3()
graph = gtsam.NonlinearFactorGraph() def check(actual):
for R in rotations: # Check that optimizing yields the identity
graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL)) self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
initial = gtsam.Values() # Check that logging output prints out 3 lines (exact intermediate values differ by OS)
initial.insert(KEY, R) self.assertEqual(self.capturedOutput.getvalue().count('\n'), 3)
self.params = gtsam.GaussNewtonParams() self.check = check
self.optimizer = gtsam.GaussNewtonOptimizer(
graph, initial, self.params)
self.lmparams = gtsam.LevenbergMarquardtParams() self.graph = gtsam.NonlinearFactorGraph()
self.lmoptimizer = gtsam.LevenbergMarquardtOptimizer( for R in rotations:
graph, initial, self.lmparams self.graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
) self.initial = gtsam.Values()
self.initial.insert(KEY, R)
# setup output capture # setup output capture
self.capturedOutput = StringIO() self.capturedOutput = StringIO()
@ -64,25 +63,28 @@ class TestOptimizeComet(GtsamTestCase):
print(error) print(error)
# Wrapper function sets the hook and calls optimizer.optimize() for us. # Wrapper function sets the hook and calls optimizer.optimize() for us.
gtsam_optimize(self.optimizer, self.params, hook) params = gtsam.GaussNewtonParams()
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial)
# Check that optimizing yields the identity. self.check(actual)
actual = self.optimizer.values() actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial, params)
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6) self.check(actual)
self.assertEqual(self.capturedOutput.getvalue(), actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params),
"0.020000000000000004\n0.010000000000000005\n0.010000000000000004\n") params, hook)
self.check(actual)
def test_lm_simple_printing(self): def test_lm_simple_printing(self):
"""Make sure we are properly terminating LM""" """Make sure we are properly terminating LM"""
def hook(_, error): def hook(_, error):
print(error) print(error)
gtsam_optimize(self.lmoptimizer, self.lmparams, hook) params = gtsam.LevenbergMarquardtParams()
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial)
actual = self.lmoptimizer.values() self.check(actual)
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6) actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial,
self.assertEqual(self.capturedOutput.getvalue(), params)
"0.020000000000000004\n0.010000000000249996\n0.009999999999999998\n") self.check(actual)
actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params),
params, hook)
@unittest.skip("Not a test we want run every time, as needs comet.ml account") @unittest.skip("Not a test we want run every time, as needs comet.ml account")
def test_comet(self): def test_comet(self):

View File

@ -6,6 +6,50 @@ Author: Jing Wu and Frank Dellaert
from gtsam import NonlinearOptimizer, NonlinearOptimizerParams from gtsam import NonlinearOptimizer, NonlinearOptimizerParams
import gtsam import gtsam
from typing import Any, Callable
# Maps each supported NonlinearOptimizer class to the params class its
# constructor expects, so optimize_using can construct a default params
# object when the caller does not supply one.  Note the GNC optimizers
# reuse the inner optimizer's params types.
OPTIMIZER_PARAMS_MAP = {
    gtsam.GaussNewtonOptimizer: gtsam.GaussNewtonParams,
    gtsam.LevenbergMarquardtOptimizer: gtsam.LevenbergMarquardtParams,
    gtsam.DoglegOptimizer: gtsam.DoglegParams,
    gtsam.GncGaussNewtonOptimizer: gtsam.GaussNewtonParams,
    gtsam.GncLMOptimizer: gtsam.LevenbergMarquardtParams
}
def optimize_using(OptimizerClass, hook) -> Callable[..., gtsam.Values]:
    """ Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration
    hook.

    Example usage:
        solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, init, params)

    Args:
        OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer,
            LevenbergMarquardtOptimizer)
        hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer,
            error) and return should be None.

    Returns:
        (Callable[..., gtsam.Values]): Call the returned function with the usual NonlinearOptimizer
            arguments (will be forwarded to constructor) and it will return a Values object
            representing the solution. See example usage above.
    """

    def wrapped_optimize(*args):
        # If the caller passed a params object, attach the hook to it in place;
        # otherwise append a default params object of the matching type.
        for arg in args:
            if isinstance(arg, gtsam.NonlinearOptimizerParams):
                arg.iterationHook = lambda iteration, error_before, error_after: hook(
                    optimizer, error_after)
                break
        else:
            params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
            params.iterationHook = lambda iteration, error_before, error_after: hook(
                optimizer, error_after)
            args = (*args, params)
        # The lambdas above close over `optimizer` by name (late binding): the name
        # is bound here, before optimize() ever invokes the iteration hook.
        optimizer = OptimizerClass(*args)
        # Report the initial error as well, so the hook sees iteration 0.
        hook(optimizer, optimizer.error())
        return optimizer.optimize()

    return wrapped_optimize
def optimize(optimizer, check_convergence, hook): def optimize(optimizer, check_convergence, hook):
@ -37,6 +81,7 @@ def gtsam_optimize(optimizer,
params, params,
hook): hook):
""" Given an optimizer and params, iterate until convergence. """ Given an optimizer and params, iterate until convergence.
Recommend using optimize_using instead.
After each iteration, hook(optimizer) is called. After each iteration, hook(optimizer) is called.
After the function, use values and errors to get the result. After the function, use values and errors to get the result.
Arguments: Arguments: