Create convenience wrapper function in logging_optimizer

release/4.3a0
Gerry Chen 2022-04-20 16:21:59 -04:00
parent 61eef0639a
commit 5796fe3488
2 changed files with 73 additions and 26 deletions
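
The new optimize_using helper curries an optimizer class together with an iteration hook, then forwards whatever arguments you pass (graph, initial values, and optionally params) to the optimizer constructor. A minimal usage sketch, assuming the same Karcher-mean fixture as the updated tests below (the Rot3.Expmap initialization is an assumption, since the test builds R outside the hunks shown here):

    import numpy as np
    import gtsam
    from gtsam import Rot3
    from gtsam.utils.logging_optimizer import optimize_using

    KEY = 0
    MODEL = gtsam.noiseModel.Unit.Create(3)

    # Two rotation priors whose mean is the identity (assumed fixture).
    R = Rot3.Expmap(np.array([0.1, 0, 0]))
    graph = gtsam.NonlinearFactorGraph()
    for rot in (R, R.inverse()):
        graph.add(gtsam.PriorFactorRot3(KEY, rot, MODEL))
    initial = gtsam.Values()
    initial.insert(KEY, R)

    def hook(optimizer, error):
        # Called once with the initial error, then again after every iteration.
        print(error)

    # params is optional: if omitted, optimize_using builds a default one
    # via OPTIMIZER_PARAMS_MAP.
    params = gtsam.GaussNewtonParams()
    solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, initial, params)
    print(solution.atRot3(KEY))
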


@@ -18,7 +18,7 @@ import numpy as np
 from gtsam import Rot3
 from gtsam.utils.test_case import GtsamTestCase
-from gtsam.utils.logging_optimizer import gtsam_optimize
+from gtsam.utils.logging_optimizer import gtsam_optimize, optimize_using

 KEY = 0
 MODEL = gtsam.noiseModel.Unit.Create(3)
@@ -34,19 +34,18 @@ class TestOptimizeComet(GtsamTestCase):
         rotations = {R, R.inverse()}  # mean is the identity
         self.expected = Rot3()
-        graph = gtsam.NonlinearFactorGraph()
-        for R in rotations:
-            graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
-        initial = gtsam.Values()
-        initial.insert(KEY, R)
-        self.params = gtsam.GaussNewtonParams()
-        self.optimizer = gtsam.GaussNewtonOptimizer(
-            graph, initial, self.params)
+        def check(actual):
+            # Check that optimizing yields the identity
+            self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
+            # Check that logging output prints out 3 lines (exact intermediate values differ by OS)
+            self.assertEqual(self.capturedOutput.getvalue().count('\n'), 3)
+        self.check = check
-        self.lmparams = gtsam.LevenbergMarquardtParams()
-        self.lmoptimizer = gtsam.LevenbergMarquardtOptimizer(
-            graph, initial, self.lmparams
-        )
+        self.graph = gtsam.NonlinearFactorGraph()
+        for R in rotations:
+            self.graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
+        self.initial = gtsam.Values()
+        self.initial.insert(KEY, R)

         # setup output capture
         self.capturedOutput = StringIO()
@@ -64,25 +63,28 @@ class TestOptimizeComet(GtsamTestCase):
             print(error)
         # Wrapper function sets the hook and calls optimizer.optimize() for us.
-        gtsam_optimize(self.optimizer, self.params, hook)
-        # Check that optimizing yields the identity.
-        actual = self.optimizer.values()
-        self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
-        self.assertEqual(self.capturedOutput.getvalue(),
-                         "0.020000000000000004\n0.010000000000000005\n0.010000000000000004\n")
+        params = gtsam.GaussNewtonParams()
+        actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial)
+        self.check(actual)
+        actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial, params)
+        self.check(actual)
+        actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params),
+                                params, hook)
+        self.check(actual)

     def test_lm_simple_printing(self):
         """Make sure we are properly terminating LM"""
         def hook(_, error):
             print(error)
-        gtsam_optimize(self.lmoptimizer, self.lmparams, hook)
-        actual = self.lmoptimizer.values()
-        self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
-        self.assertEqual(self.capturedOutput.getvalue(),
-                         "0.020000000000000004\n0.010000000000249996\n0.009999999999999998\n")
+        params = gtsam.LevenbergMarquardtParams()
+        actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial)
+        self.check(actual)
+        actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial,
+                                                                         params)
+        self.check(actual)
+        actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params),
+                                params, hook)

     @unittest.skip("Not a test we want run every time, as needs comet.ml account")
     def test_comet(self):
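
For comparison, the pre-existing gtsam_optimize entry point (whose docstring in the second file below now recommends optimize_using) expects an already constructed optimizer plus its params, as the last call in each test above shows. A sketch reusing the graph, initial, and hook from the example at the top (assumed to be in scope):

    from gtsam.utils.logging_optimizer import gtsam_optimize

    params = gtsam.LevenbergMarquardtParams()
    optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)
    # Iterates the optimizer until convergence, calling hook after each
    # iteration; per the updated test it returns the resulting Values.
    result = gtsam_optimize(optimizer, params, hook)
    print(result.atRot3(KEY))
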


@@ -6,6 +6,50 @@ Author: Jing Wu and Frank Dellaert
 from gtsam import NonlinearOptimizer, NonlinearOptimizerParams
 import gtsam
+from typing import Any, Callable
+
+OPTIMIZER_PARAMS_MAP = {
+    gtsam.GaussNewtonOptimizer: gtsam.GaussNewtonParams,
+    gtsam.LevenbergMarquardtOptimizer: gtsam.LevenbergMarquardtParams,
+    gtsam.DoglegOptimizer: gtsam.DoglegParams,
+    gtsam.GncGaussNewtonOptimizer: gtsam.GaussNewtonParams,
+    gtsam.GncLMOptimizer: gtsam.LevenbergMarquardtParams
+}
+
+
+def optimize_using(OptimizerClass, hook) -> Callable[[Any], gtsam.Values]:
+    """ Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration
+    hook.
+    Example usage:
+        solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, init, params)
+
+    Args:
+        OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer,
+            LevenbergMarquardtOptimizer)
+        hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer,
+            error) and return should be None.
+
+    Returns:
+        (Callable[*, gtsam.Values]): Call the returned function with the usual NonlinearOptimizer
+            arguments (will be forwarded to constructor) and it will return a Values object
+            representing the solution. See example usage above.
+    """
+    def wrapped_optimize(*args):
+        for arg in args:
+            if isinstance(arg, gtsam.NonlinearOptimizerParams):
+                arg.iterationHook = lambda iteration, error_before, error_after: hook(
+                    optimizer, error_after)
+                break
+        else:
+            params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
+            params.iterationHook = lambda iteration, error_before, error_after: hook(
+                optimizer, error_after)
+            args = (*args, params)
+        optimizer = OptimizerClass(*args)
+        hook(optimizer, optimizer.error())
+        return optimizer.optimize()
+    return wrapped_optimize
+
+
 def optimize(optimizer, check_convergence, hook):
@@ -37,6 +81,7 @@ def gtsam_optimize(optimizer,
                    params,
                    hook):
     """ Given an optimizer and params, iterate until convergence.
+    Recommend using optimize_using instead.
     After each iteration, hook(optimizer) is called.
     After the function, use values and errors to get the result.
     Arguments: