Create convenience wrapper function in logging_optimizer
parent
61eef0639a
commit
5796fe3488
|
@ -18,7 +18,7 @@ import numpy as np
|
||||||
from gtsam import Rot3
|
from gtsam import Rot3
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
from gtsam.utils.logging_optimizer import gtsam_optimize
|
from gtsam.utils.logging_optimizer import gtsam_optimize, optimize_using
|
||||||
|
|
||||||
KEY = 0
|
KEY = 0
|
||||||
MODEL = gtsam.noiseModel.Unit.Create(3)
|
MODEL = gtsam.noiseModel.Unit.Create(3)
|
||||||
|
@ -34,19 +34,18 @@ class TestOptimizeComet(GtsamTestCase):
|
||||||
rotations = {R, R.inverse()} # mean is the identity
|
rotations = {R, R.inverse()} # mean is the identity
|
||||||
self.expected = Rot3()
|
self.expected = Rot3()
|
||||||
|
|
||||||
graph = gtsam.NonlinearFactorGraph()
|
def check(actual):
|
||||||
for R in rotations:
|
# Check that optimizing yields the identity
|
||||||
graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
|
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
|
||||||
initial = gtsam.Values()
|
# Check that logging output prints out 3 lines (exact intermediate values differ by OS)
|
||||||
initial.insert(KEY, R)
|
self.assertEqual(self.capturedOutput.getvalue().count('\n'), 3)
|
||||||
self.params = gtsam.GaussNewtonParams()
|
self.check = check
|
||||||
self.optimizer = gtsam.GaussNewtonOptimizer(
|
|
||||||
graph, initial, self.params)
|
|
||||||
|
|
||||||
self.lmparams = gtsam.LevenbergMarquardtParams()
|
self.graph = gtsam.NonlinearFactorGraph()
|
||||||
self.lmoptimizer = gtsam.LevenbergMarquardtOptimizer(
|
for R in rotations:
|
||||||
graph, initial, self.lmparams
|
self.graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
|
||||||
)
|
self.initial = gtsam.Values()
|
||||||
|
self.initial.insert(KEY, R)
|
||||||
|
|
||||||
# setup output capture
|
# setup output capture
|
||||||
self.capturedOutput = StringIO()
|
self.capturedOutput = StringIO()
|
||||||
|
@ -64,25 +63,28 @@ class TestOptimizeComet(GtsamTestCase):
|
||||||
print(error)
|
print(error)
|
||||||
|
|
||||||
# Wrapper function sets the hook and calls optimizer.optimize() for us.
|
# Wrapper function sets the hook and calls optimizer.optimize() for us.
|
||||||
gtsam_optimize(self.optimizer, self.params, hook)
|
params = gtsam.GaussNewtonParams()
|
||||||
|
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial)
|
||||||
# Check that optimizing yields the identity.
|
self.check(actual)
|
||||||
actual = self.optimizer.values()
|
actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial, params)
|
||||||
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
|
self.check(actual)
|
||||||
self.assertEqual(self.capturedOutput.getvalue(),
|
actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params),
|
||||||
"0.020000000000000004\n0.010000000000000005\n0.010000000000000004\n")
|
params, hook)
|
||||||
|
self.check(actual)
|
||||||
|
|
||||||
def test_lm_simple_printing(self):
|
def test_lm_simple_printing(self):
|
||||||
"""Make sure we are properly terminating LM"""
|
"""Make sure we are properly terminating LM"""
|
||||||
def hook(_, error):
|
def hook(_, error):
|
||||||
print(error)
|
print(error)
|
||||||
|
|
||||||
gtsam_optimize(self.lmoptimizer, self.lmparams, hook)
|
params = gtsam.LevenbergMarquardtParams()
|
||||||
|
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial)
|
||||||
actual = self.lmoptimizer.values()
|
self.check(actual)
|
||||||
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
|
actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial,
|
||||||
self.assertEqual(self.capturedOutput.getvalue(),
|
params)
|
||||||
"0.020000000000000004\n0.010000000000249996\n0.009999999999999998\n")
|
self.check(actual)
|
||||||
|
actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params),
|
||||||
|
params, hook)
|
||||||
|
|
||||||
@unittest.skip("Not a test we want run every time, as needs comet.ml account")
|
@unittest.skip("Not a test we want run every time, as needs comet.ml account")
|
||||||
def test_comet(self):
|
def test_comet(self):
|
||||||
|
|
|
@ -6,6 +6,50 @@ Author: Jing Wu and Frank Dellaert
|
||||||
|
|
||||||
from gtsam import NonlinearOptimizer, NonlinearOptimizerParams
|
from gtsam import NonlinearOptimizer, NonlinearOptimizerParams
|
||||||
import gtsam
|
import gtsam
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
OPTIMIZER_PARAMS_MAP = {
|
||||||
|
gtsam.GaussNewtonOptimizer: gtsam.GaussNewtonParams,
|
||||||
|
gtsam.LevenbergMarquardtOptimizer: gtsam.LevenbergMarquardtParams,
|
||||||
|
gtsam.DoglegOptimizer: gtsam.DoglegParams,
|
||||||
|
gtsam.GncGaussNewtonOptimizer: gtsam.GaussNewtonParams,
|
||||||
|
gtsam.GncLMOptimizer: gtsam.LevenbergMarquardtParams
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def optimize_using(OptimizerClass, hook) -> Callable[[Any], gtsam.Values]:
|
||||||
|
""" Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration
|
||||||
|
hook.
|
||||||
|
Example usage:
|
||||||
|
solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, init, params)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer,
|
||||||
|
LevenbergMarquadrtOptimizer)
|
||||||
|
hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer,
|
||||||
|
error) and return should be None.
|
||||||
|
Returns:
|
||||||
|
(Callable[*, gtsam.Values]): Call the returned function with the usual NonlinearOptimizer
|
||||||
|
arguments (will be forwarded to constructor) and it will return a Values object
|
||||||
|
representing the solution. See example usage above.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapped_optimize(*args):
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, gtsam.NonlinearOptimizerParams):
|
||||||
|
arg.iterationHook = lambda iteration, error_before, error_after: hook(
|
||||||
|
optimizer, error_after)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
params = OPTIMIZER_PARAMS_MAP[OptimizerClass]()
|
||||||
|
params.iterationHook = lambda iteration, error_before, error_after: hook(
|
||||||
|
optimizer, error_after)
|
||||||
|
args = (*args, params)
|
||||||
|
optimizer = OptimizerClass(*args)
|
||||||
|
hook(optimizer, optimizer.error())
|
||||||
|
return optimizer.optimize()
|
||||||
|
|
||||||
|
return wrapped_optimize
|
||||||
|
|
||||||
|
|
||||||
def optimize(optimizer, check_convergence, hook):
|
def optimize(optimizer, check_convergence, hook):
|
||||||
|
@ -37,6 +81,7 @@ def gtsam_optimize(optimizer,
|
||||||
params,
|
params,
|
||||||
hook):
|
hook):
|
||||||
""" Given an optimizer and params, iterate until convergence.
|
""" Given an optimizer and params, iterate until convergence.
|
||||||
|
Recommend using optimize_using instead.
|
||||||
After each iteration, hook(optimizer) is called.
|
After each iteration, hook(optimizer) is called.
|
||||||
After the function, use values and errors to get the result.
|
After the function, use values and errors to get the result.
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
Loading…
Reference in New Issue