Change `optimize_using` to simpler function call
parent
1e03c8b195
commit
78c7a6b72c
|
@ -66,9 +66,9 @@ class TestOptimizeComet(GtsamTestCase):
|
|||
|
||||
# Wrapper function sets the hook and calls optimizer.optimize() for us.
|
||||
params = gtsam.GaussNewtonParams()
|
||||
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial)
|
||||
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial)
|
||||
self.check(actual)
|
||||
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial, params)
|
||||
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial, params)
|
||||
self.check(actual)
|
||||
actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params),
|
||||
params, hook)
|
||||
|
@ -80,10 +80,10 @@ class TestOptimizeComet(GtsamTestCase):
|
|||
print(error)
|
||||
|
||||
params = gtsam.LevenbergMarquardtParams()
|
||||
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial)
|
||||
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial)
|
||||
self.check(actual)
|
||||
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial,
|
||||
params)
|
||||
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial,
|
||||
params)
|
||||
self.check(actual)
|
||||
actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params),
|
||||
params, hook)
|
||||
|
|
|
@ -17,39 +17,42 @@ OPTIMIZER_PARAMS_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def optimize_using(OptimizerClass, hook) -> Callable[[Any], gtsam.Values]:
|
||||
def optimize_using(OptimizerClass, hook, *args) -> gtsam.Values:
|
||||
""" Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration
|
||||
hook.
|
||||
Example usage:
|
||||
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, init, params)
|
||||
```python
|
||||
def hook(optimizer, error):
|
||||
print("iteration {:}, error = {:}".format(optimizer.iterations(), error))
|
||||
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook, graph, init, params)
|
||||
```
|
||||
Iteration hook's args are (optimizer, error) and return type should be None
|
||||
|
||||
Args:
|
||||
OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer,
|
||||
LevenbergMarquadrtOptimizer)
|
||||
LevenbergMarquardtOptimizer)
|
||||
hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer,
|
||||
error) and return should be None.
|
||||
*args: Arguments that would be passed into the OptimizerClass constructor, usually:
|
||||
graph, init, [params]
|
||||
Returns:
|
||||
(Callable[*, gtsam.Values]): Call the returned function with the usual NonlinearOptimizer
|
||||
arguments (will be forwarded to constructor) and it will return a Values object
|
||||
representing the solution. See example usage above.
|
||||
(gtsam.Values): A Values object representing the optimization solution.
|
||||
"""
|
||||
|
||||
def wrapped_optimize(*args):
|
||||
for arg in args:
|
||||
if isinstance(arg, gtsam.NonlinearOptimizerParams):
|
||||
arg.iterationHook = lambda iteration, error_before, error_after: hook(
|
||||
optimizer, error_after)
|
||||
break
|
||||
else:
|
||||
params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
|
||||
params.iterationHook = lambda iteration, error_before, error_after: hook(
|
||||
# Add the iteration hook to the NonlinearOptimizerParams
|
||||
for arg in args:
|
||||
if isinstance(arg, gtsam.NonlinearOptimizerParams):
|
||||
arg.iterationHook = lambda iteration, error_before, error_after: hook(
|
||||
optimizer, error_after)
|
||||
args = (*args, params)
|
||||
optimizer = OptimizerClass(*args)
|
||||
hook(optimizer, optimizer.error())
|
||||
return optimizer.optimize()
|
||||
|
||||
return wrapped_optimize
|
||||
break
|
||||
else:
|
||||
params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
|
||||
params.iterationHook = lambda iteration, error_before, error_after: hook(
|
||||
optimizer, error_after)
|
||||
args = (*args, params)
|
||||
# Construct Optimizer and optimize
|
||||
optimizer = OptimizerClass(*args)
|
||||
hook(optimizer, optimizer.error()) # Call hook once with init values to match behavior below
|
||||
return optimizer.optimize()
|
||||
|
||||
|
||||
def optimize(optimizer, check_convergence, hook):
|
||||
|
|
Loading…
Reference in New Issue