Change `optimize_using` to simpler function call

release/4.3a0
Gerry Chen 2022-05-13 10:24:49 -04:00
parent 1e03c8b195
commit 78c7a6b72c
No known key found for this signature in database
GPG Key ID: E9845092D3A57286
2 changed files with 30 additions and 27 deletions

View File

@ -66,9 +66,9 @@ class TestOptimizeComet(GtsamTestCase):
# Wrapper function sets the hook and calls optimizer.optimize() for us. # Wrapper function sets the hook and calls optimizer.optimize() for us.
params = gtsam.GaussNewtonParams() params = gtsam.GaussNewtonParams()
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial) actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial)
self.check(actual) self.check(actual)
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial, params) actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial, params)
self.check(actual) self.check(actual)
actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params), actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params),
params, hook) params, hook)
@ -80,9 +80,9 @@ class TestOptimizeComet(GtsamTestCase):
print(error) print(error)
params = gtsam.LevenbergMarquardtParams() params = gtsam.LevenbergMarquardtParams()
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial) actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial)
self.check(actual) self.check(actual)
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial, actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial,
params) params)
self.check(actual) self.check(actual)
actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params), actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params),

View File

@ -17,24 +17,28 @@ OPTIMIZER_PARAMS_MAP = {
} }
def optimize_using(OptimizerClass, hook) -> Callable[[Any], gtsam.Values]: def optimize_using(OptimizerClass, hook, *args) -> gtsam.Values:
""" Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration """ Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration
hook. hook.
Example usage: Example usage:
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, init, params) ```python
def hook(optimizer, error):
print("iteration {:}, error = {:}".format(optimizer.iterations(), error))
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook, graph, init, params)
```
Iteration hook's args are (optimizer, error) and return type should be None
Args: Args:
OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer, OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer,
LevenbergMarquadrtOptimizer) LevenbergMarquardtOptimizer)
hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer, hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer,
error) and return should be None. error) and return should be None.
*args: Arguments that would be passed into the OptimizerClass constructor, usually:
graph, init, [params]
Returns: Returns:
(Callable[*, gtsam.Values]): Call the returned function with the usual NonlinearOptimizer (gtsam.Values): A Values object representing the optimization solution.
arguments (will be forwarded to constructor) and it will return a Values object
representing the solution. See example usage above.
""" """
# Add the iteration hook to the NonlinearOptimizerParams
def wrapped_optimize(*args):
for arg in args: for arg in args:
if isinstance(arg, gtsam.NonlinearOptimizerParams): if isinstance(arg, gtsam.NonlinearOptimizerParams):
arg.iterationHook = lambda iteration, error_before, error_after: hook( arg.iterationHook = lambda iteration, error_before, error_after: hook(
@ -45,12 +49,11 @@ def optimize_using(OptimizerClass, hook) -> Callable[[Any], gtsam.Values]:
params.iterationHook = lambda iteration, error_before, error_after: hook( params.iterationHook = lambda iteration, error_before, error_after: hook(
optimizer, error_after) optimizer, error_after)
args = (*args, params) args = (*args, params)
# Construct Optimizer and optimize
optimizer = OptimizerClass(*args) optimizer = OptimizerClass(*args)
hook(optimizer, optimizer.error()) hook(optimizer, optimizer.error()) # Call hook once with init values to match behavior below
return optimizer.optimize() return optimizer.optimize()
return wrapped_optimize
def optimize(optimizer, check_convergence, hook): def optimize(optimizer, check_convergence, hook):
""" Given an optimizer and a convergence check, iterate until convergence. """ Given an optimizer and a convergence check, iterate until convergence.