Create convenience wrapper function in logging_optimizer
parent
61eef0639a
commit
5796fe3488
|
@ -18,7 +18,7 @@ import numpy as np
|
|||
from gtsam import Rot3
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
from gtsam.utils.logging_optimizer import gtsam_optimize
|
||||
from gtsam.utils.logging_optimizer import gtsam_optimize, optimize_using
|
||||
|
||||
KEY = 0
|
||||
MODEL = gtsam.noiseModel.Unit.Create(3)
|
||||
|
@ -34,19 +34,18 @@ class TestOptimizeComet(GtsamTestCase):
|
|||
rotations = {R, R.inverse()} # mean is the identity
|
||||
self.expected = Rot3()
|
||||
|
||||
graph = gtsam.NonlinearFactorGraph()
|
||||
for R in rotations:
|
||||
graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
|
||||
initial = gtsam.Values()
|
||||
initial.insert(KEY, R)
|
||||
self.params = gtsam.GaussNewtonParams()
|
||||
self.optimizer = gtsam.GaussNewtonOptimizer(
|
||||
graph, initial, self.params)
|
||||
def check(actual):
|
||||
# Check that optimizing yields the identity
|
||||
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
|
||||
# Check that logging output prints out 3 lines (exact intermediate values differ by OS)
|
||||
self.assertEqual(self.capturedOutput.getvalue().count('\n'), 3)
|
||||
self.check = check
|
||||
|
||||
self.lmparams = gtsam.LevenbergMarquardtParams()
|
||||
self.lmoptimizer = gtsam.LevenbergMarquardtOptimizer(
|
||||
graph, initial, self.lmparams
|
||||
)
|
||||
self.graph = gtsam.NonlinearFactorGraph()
|
||||
for R in rotations:
|
||||
self.graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
|
||||
self.initial = gtsam.Values()
|
||||
self.initial.insert(KEY, R)
|
||||
|
||||
# setup output capture
|
||||
self.capturedOutput = StringIO()
|
||||
|
@ -64,25 +63,28 @@ class TestOptimizeComet(GtsamTestCase):
|
|||
print(error)
|
||||
|
||||
# Wrapper function sets the hook and calls optimizer.optimize() for us.
|
||||
gtsam_optimize(self.optimizer, self.params, hook)
|
||||
|
||||
# Check that optimizing yields the identity.
|
||||
actual = self.optimizer.values()
|
||||
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
|
||||
self.assertEqual(self.capturedOutput.getvalue(),
|
||||
"0.020000000000000004\n0.010000000000000005\n0.010000000000000004\n")
|
||||
params = gtsam.GaussNewtonParams()
|
||||
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial)
|
||||
self.check(actual)
|
||||
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial, params)
|
||||
self.check(actual)
|
||||
actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params),
|
||||
params, hook)
|
||||
self.check(actual)
|
||||
|
||||
def test_lm_simple_printing(self):
|
||||
"""Make sure we are properly terminating LM"""
|
||||
def hook(_, error):
|
||||
print(error)
|
||||
|
||||
gtsam_optimize(self.lmoptimizer, self.lmparams, hook)
|
||||
|
||||
actual = self.lmoptimizer.values()
|
||||
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
|
||||
self.assertEqual(self.capturedOutput.getvalue(),
|
||||
"0.020000000000000004\n0.010000000000249996\n0.009999999999999998\n")
|
||||
params = gtsam.LevenbergMarquardtParams()
|
||||
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial)
|
||||
self.check(actual)
|
||||
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial,
|
||||
params)
|
||||
self.check(actual)
|
||||
actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params),
|
||||
params, hook)
|
||||
|
||||
@unittest.skip("Not a test we want run every time, as needs comet.ml account")
|
||||
def test_comet(self):
|
||||
|
|
|
@ -6,6 +6,50 @@ Author: Jing Wu and Frank Dellaert
|
|||
|
||||
from gtsam import NonlinearOptimizer, NonlinearOptimizerParams
|
||||
import gtsam
|
||||
from typing import Any, Callable
|
||||
|
||||
OPTIMIZER_PARAMS_MAP = {
|
||||
gtsam.GaussNewtonOptimizer: gtsam.GaussNewtonParams,
|
||||
gtsam.LevenbergMarquardtOptimizer: gtsam.LevenbergMarquardtParams,
|
||||
gtsam.DoglegOptimizer: gtsam.DoglegParams,
|
||||
gtsam.GncGaussNewtonOptimizer: gtsam.GaussNewtonParams,
|
||||
gtsam.GncLMOptimizer: gtsam.LevenbergMarquardtParams
|
||||
}
|
||||
|
||||
|
||||
def optimize_using(OptimizerClass, hook) -> Callable[[Any], gtsam.Values]:
|
||||
""" Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration
|
||||
hook.
|
||||
Example usage:
|
||||
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, init, params)
|
||||
|
||||
Args:
|
||||
OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer,
|
||||
LevenbergMarquadrtOptimizer)
|
||||
hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer,
|
||||
error) and return should be None.
|
||||
Returns:
|
||||
(Callable[*, gtsam.Values]): Call the returned function with the usual NonlinearOptimizer
|
||||
arguments (will be forwarded to constructor) and it will return a Values object
|
||||
representing the solution. See example usage above.
|
||||
"""
|
||||
|
||||
def wrapped_optimize(*args):
|
||||
for arg in args:
|
||||
if isinstance(arg, gtsam.NonlinearOptimizerParams):
|
||||
arg.iterationHook = lambda iteration, error_before, error_after: hook(
|
||||
optimizer, error_after)
|
||||
break
|
||||
else:
|
||||
params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
|
||||
params.iterationHook = lambda iteration, error_before, error_after: hook(
|
||||
optimizer, error_after)
|
||||
args = (*args, params)
|
||||
optimizer = OptimizerClass(*args)
|
||||
hook(optimizer, optimizer.error())
|
||||
return optimizer.optimize()
|
||||
|
||||
return wrapped_optimize
|
||||
|
||||
|
||||
def optimize(optimizer, check_convergence, hook):
|
||||
|
@ -37,6 +81,7 @@ def gtsam_optimize(optimizer,
|
|||
params,
|
||||
hook):
|
||||
""" Given an optimizer and params, iterate until convergence.
|
||||
Recommend using optimize_using instead.
|
||||
After each iteration, hook(optimizer) is called.
|
||||
After the function, use values and errors to get the result.
|
||||
Arguments:
|
||||
|
|
Loading…
Reference in New Issue