Merge pull request #345 from borglab/feature/logging_optimizer
Add logging (hooked) optimizerrelease/4.3a0
commit
18e80b83aa
|
@ -0,0 +1,79 @@
|
|||
"""
|
||||
Unit tests for optimization that logs to comet.ml.
|
||||
Author: Jing Wu and Frank Dellaert
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
|
||||
import gtsam
|
||||
import numpy as np
|
||||
from gtsam import Rot3
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
from gtsam.utils.logging_optimizer import gtsam_optimize
|
||||
|
||||
KEY = 0
|
||||
MODEL = gtsam.noiseModel_Unit.Create(3)
|
||||
|
||||
|
||||
class TestOptimizeComet(GtsamTestCase):
|
||||
"""Check correct logging to comet.ml."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up a small Karcher mean optimization example."""
|
||||
# Grabbed from KarcherMeanFactor unit tests.
|
||||
R = Rot3.Expmap(np.array([0.1, 0, 0]))
|
||||
rotations = {R, R.inverse()} # mean is the identity
|
||||
self.expected = Rot3()
|
||||
|
||||
graph = gtsam.NonlinearFactorGraph()
|
||||
for R in rotations:
|
||||
graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
|
||||
initial = gtsam.Values()
|
||||
initial.insert(KEY, R)
|
||||
self.params = gtsam.GaussNewtonParams()
|
||||
self.optimizer = gtsam.GaussNewtonOptimizer(
|
||||
graph, initial, self.params)
|
||||
|
||||
def test_simple_printing(self):
|
||||
"""Test with a simple hook."""
|
||||
|
||||
# Provide a hook that just prints
|
||||
def hook(_, error: float):
|
||||
print(error)
|
||||
|
||||
# Only thing we require from optimizer is an iterate method
|
||||
gtsam_optimize(self.optimizer, self.params, hook)
|
||||
|
||||
# Check that optimizing yields the identity.
|
||||
actual = self.optimizer.values()
|
||||
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
|
||||
|
||||
@unittest.skip("Not a test we want run every time, as needs comet.ml account")
|
||||
def test_comet(self):
|
||||
"""Test with a comet hook."""
|
||||
from comet_ml import Experiment
|
||||
comet = Experiment(project_name="Testing",
|
||||
auto_output_logging="native")
|
||||
comet.log_dataset_info(name="Karcher", path="shonan")
|
||||
comet.add_tag("GaussNewton")
|
||||
comet.log_parameter("method", "GaussNewton")
|
||||
time = datetime.now()
|
||||
comet.set_name("GaussNewton-" + str(time.month) + "/" + str(time.day) + " "
|
||||
+ str(time.hour)+":"+str(time.minute)+":"+str(time.second))
|
||||
|
||||
# I want to do some comet thing here
|
||||
def hook(optimizer, error: float):
|
||||
comet.log_metric("Karcher error",
|
||||
error, optimizer.iterations())
|
||||
|
||||
gtsam_optimize(self.optimizer, self.params, hook)
|
||||
comet.end()
|
||||
|
||||
actual = self.optimizer.values()
|
||||
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
Optimization with logging via a hook.
|
||||
Author: Jing Wu and Frank Dellaert
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
from gtsam import NonlinearOptimizer, NonlinearOptimizerParams
|
||||
import gtsam
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def optimize(optimizer: T, check_convergence, hook):
|
||||
""" Given an optimizer and a convergence check, iterate until convergence.
|
||||
After each iteration, hook(optimizer, error) is called.
|
||||
After the function, use values and errors to get the result.
|
||||
Arguments:
|
||||
optimizer (T): needs an iterate and an error function.
|
||||
check_convergence: T * float * float -> bool
|
||||
hook -- hook function to record the error
|
||||
"""
|
||||
# the optimizer is created with default values which incur the error below
|
||||
current_error = optimizer.error()
|
||||
hook(optimizer, current_error)
|
||||
|
||||
# Iterative loop
|
||||
while True:
|
||||
# Do next iteration
|
||||
optimizer.iterate()
|
||||
new_error = optimizer.error()
|
||||
hook(optimizer, new_error)
|
||||
if check_convergence(optimizer, current_error, new_error):
|
||||
return
|
||||
current_error = new_error
|
||||
|
||||
|
||||
def gtsam_optimize(optimizer: NonlinearOptimizer,
|
||||
params: NonlinearOptimizerParams,
|
||||
hook):
|
||||
""" Given an optimizer and params, iterate until convergence.
|
||||
After each iteration, hook(optimizer) is called.
|
||||
After the function, use values and errors to get the result.
|
||||
Arguments:
|
||||
optimizer {NonlinearOptimizer} -- Nonlinear optimizer
|
||||
params {NonlinearOptimizarParams} -- Nonlinear optimizer parameters
|
||||
hook -- hook function to record the error
|
||||
"""
|
||||
def check_convergence(optimizer, current_error, new_error):
|
||||
return (optimizer.iterations() >= params.getMaxIterations()) or (
|
||||
gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(),
|
||||
current_error, new_error))
|
||||
optimize(optimizer, check_convergence, hook)
|
Loading…
Reference in New Issue