Change `optimize_using` to simpler function call

release/4.3a0
Gerry Chen 2022-05-13 10:24:49 -04:00
parent 1e03c8b195
commit 78c7a6b72c
No known key found for this signature in database
GPG Key ID: E9845092D3A57286
2 changed files with 30 additions and 27 deletions

View File

@@ -66,9 +66,9 @@ class TestOptimizeComet(GtsamTestCase):
# Wrapper function sets the hook and calls optimizer.optimize() for us. # Wrapper function sets the hook and calls optimizer.optimize() for us.
params = gtsam.GaussNewtonParams() params = gtsam.GaussNewtonParams()
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial) actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial)
self.check(actual) self.check(actual)
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial, params) actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial, params)
self.check(actual) self.check(actual)
actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params), actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params),
params, hook) params, hook)
@@ -80,10 +80,10 @@ class TestOptimizeComet(GtsamTestCase):
print(error) print(error)
params = gtsam.LevenbergMarquardtParams() params = gtsam.LevenbergMarquardtParams()
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial) actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial)
self.check(actual) self.check(actual)
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial, actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial,
params) params)
self.check(actual) self.check(actual)
actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params), actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params),
params, hook) params, hook)

View File

@@ -17,39 +17,42 @@ OPTIMIZER_PARAMS_MAP = {
} }
def optimize_using(OptimizerClass, hook) -> Callable[[Any], gtsam.Values]: def optimize_using(OptimizerClass, hook, *args) -> gtsam.Values:
""" Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration """ Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration
hook. hook.
Example usage: Example usage:
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, init, params) ```python
def hook(optimizer, error):
print("iteration {:}, error = {:}".format(optimizer.iterations(), error))
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook, graph, init, params)
```
Iteration hook's args are (optimizer, error) and return type should be None
Args: Args:
OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer, OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer,
LevenbergMarquadrtOptimizer) LevenbergMarquardtOptimizer)
hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer, hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer,
error) and return should be None. error) and return should be None.
*args: Arguments that would be passed into the OptimizerClass constructor, usually:
graph, init, [params]
Returns: Returns:
(Callable[*, gtsam.Values]): Call the returned function with the usual NonlinearOptimizer (gtsam.Values): A Values object representing the optimization solution.
arguments (will be forwarded to constructor) and it will return a Values object
representing the solution. See example usage above.
""" """
# Add the iteration hook to the NonlinearOptimizerParams
def wrapped_optimize(*args): for arg in args:
for arg in args: if isinstance(arg, gtsam.NonlinearOptimizerParams):
if isinstance(arg, gtsam.NonlinearOptimizerParams): arg.iterationHook = lambda iteration, error_before, error_after: hook(
arg.iterationHook = lambda iteration, error_before, error_after: hook(
optimizer, error_after)
break
else:
params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
params.iterationHook = lambda iteration, error_before, error_after: hook(
optimizer, error_after) optimizer, error_after)
args = (*args, params) break
optimizer = OptimizerClass(*args) else:
hook(optimizer, optimizer.error()) params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
return optimizer.optimize() params.iterationHook = lambda iteration, error_before, error_after: hook(
optimizer, error_after)
return wrapped_optimize args = (*args, params)
# Construct Optimizer and optimize
optimizer = OptimizerClass(*args)
hook(optimizer, optimizer.error()) # Call hook once with init values to match behavior below
return optimizer.optimize()
def optimize(optimizer, check_convergence, hook): def optimize(optimizer, check_convergence, hook):