From 78c7a6b72cb011475e0aadeaabfe2acf247d4481 Mon Sep 17 00:00:00 2001 From: Gerry Chen Date: Fri, 13 May 2022 10:24:49 -0400 Subject: [PATCH] Change `optimize_using` to simpler function call --- python/gtsam/tests/test_logging_optimizer.py | 10 ++--- python/gtsam/utils/logging_optimizer.py | 47 +++++++++++--------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/python/gtsam/tests/test_logging_optimizer.py b/python/gtsam/tests/test_logging_optimizer.py index c58e4f121..602aeffc9 100644 --- a/python/gtsam/tests/test_logging_optimizer.py +++ b/python/gtsam/tests/test_logging_optimizer.py @@ -66,9 +66,9 @@ class TestOptimizeComet(GtsamTestCase): # Wrapper function sets the hook and calls optimizer.optimize() for us. params = gtsam.GaussNewtonParams() - actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial) + actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial) self.check(actual) - actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial, params) + actual = optimize_using(gtsam.GaussNewtonOptimizer, hook, self.graph, self.initial, params) self.check(actual) actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params), params, hook) @@ -80,10 +80,10 @@ class TestOptimizeComet(GtsamTestCase): print(error) params = gtsam.LevenbergMarquardtParams() - actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial) + actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial) self.check(actual) - actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial, - params) + actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook, self.graph, self.initial, + params) self.check(actual) actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params), params, hook) diff --git a/python/gtsam/utils/logging_optimizer.py b/python/gtsam/utils/logging_optimizer.py index f89208bc5..fe2f717d8 100644 --- a/python/gtsam/utils/logging_optimizer.py +++ b/python/gtsam/utils/logging_optimizer.py @@ -17,39 +17,42 @@ OPTIMIZER_PARAMS_MAP = { } -def optimize_using(OptimizerClass, hook) -> Callable[[Any], gtsam.Values]: +def optimize_using(OptimizerClass, hook, *args) -> gtsam.Values: """ Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration hook. Example usage: - solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, init, params) + ```python + def hook(optimizer, error): + print("iteration {:}, error = {:}".format(optimizer.iterations(), error)) + solution = optimize_using(gtsam.GaussNewtonOptimizer, hook, graph, init, params) + ``` + Iteration hook's args are (optimizer, error) and return type should be None Args: OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer, - LevenbergMarquadrtOptimizer) + LevenbergMarquardtOptimizer) hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer, error) and return should be None. + *args: Arguments that would be passed into the OptimizerClass constructor, usually: + graph, init, [params] Returns: - (Callable[*, gtsam.Values]): Call the returned function with the usual NonlinearOptimizer - arguments (will be forwarded to constructor) and it will return a Values object - representing the solution. See example usage above. + (gtsam.Values): A Values object representing the optimization solution. """ - - def wrapped_optimize(*args): - for arg in args: - if isinstance(arg, gtsam.NonlinearOptimizerParams): - arg.iterationHook = lambda iteration, error_before, error_after: hook( - optimizer, error_after) - break - else: - params = OPTIMIZER_PARAMS_MAP[OptimizerClass]() - params.iterationHook = lambda iteration, error_before, error_after: hook( + # Add the iteration hook to the NonlinearOptimizerParams + for arg in args: + if isinstance(arg, gtsam.NonlinearOptimizerParams): + arg.iterationHook = lambda iteration, error_before, error_after: hook( optimizer, error_after) - args = (*args, params) - optimizer = OptimizerClass(*args) - hook(optimizer, optimizer.error()) - return optimizer.optimize() - - return wrapped_optimize + break + else: + params = OPTIMIZER_PARAMS_MAP[OptimizerClass]() + params.iterationHook = lambda iteration, error_before, error_after: hook( + optimizer, error_after) + args = (*args, params) + # Construct Optimizer and optimize + optimizer = OptimizerClass(*args) + hook(optimizer, optimizer.error()) # Call hook once with init values to match behavior below + return optimizer.optimize() def optimize(optimizer, check_convergence, hook):