From fbe9a21070c90454bfdf9e3c5dc43d6387e64ffc Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 31 Jan 2022 18:55:40 -0500 Subject: [PATCH] attempt to get custom factor tests passing --- gtsam/nonlinear/CustomFactor.cpp | 16 +++++++++---- gtsam/nonlinear/CustomFactor.h | 5 ++-- python/gtsam/tests/test_custom_factor.py | 29 ++++++++++++------------ 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/gtsam/nonlinear/CustomFactor.cpp b/gtsam/nonlinear/CustomFactor.cpp index e33caed6f..b8368c497 100644 --- a/gtsam/nonlinear/CustomFactor.cpp +++ b/gtsam/nonlinear/CustomFactor.cpp @@ -22,13 +22,14 @@ namespace gtsam { /* * Calculates the unwhitened error by invoking the callback functor (i.e. from Python). */ -Vector CustomFactor::unwhitenedError(const Values& x, boost::optional&> H) const { +Vector CustomFactor::unwhitenedError( + const Values &x, boost::optional&> H) const { if(this->active(x)) { if(H) { /* * In this case, we pass the raw pointer to the `std::vector` object directly to pybind. - * As the type `std::vector` has been marked as opaque in `preamble.h`, any changes in + * As the type `std::vector` has been marked as opaque in `preamble/base.h`, any changes in * Python will be immediately reflected on the C++ side. * * Example: @@ -43,13 +44,20 @@ Vector CustomFactor::unwhitenedError(const Values& x, boost::optionalerror_function_(*this, x, H.get_ptr()); + std::pair errorAndJacobian = + this->error_function_(*this, x, H.get_ptr()); + + Vector error = errorAndJacobian.first; + (*H) = errorAndJacobian.second; + + return error; } else { /* * In this case, we pass the a `nullptr` to pybind, and it will translate to `None` in Python. * Users can check for `None` in their callback to determine if the Jacobian is requested. */ - return this->error_function_(*this, x, nullptr); + auto errorAndJacobian = this->error_function_(*this, x, nullptr); + return errorAndJacobian.first; } } else { return Vector::Zero(this->dim()); diff --git a/gtsam/nonlinear/CustomFactor.h b/gtsam/nonlinear/CustomFactor.h index 615b5418e..6261636b5 100644 --- a/gtsam/nonlinear/CustomFactor.h +++ b/gtsam/nonlinear/CustomFactor.h @@ -35,7 +35,8 @@ class CustomFactor; * This is safe because this is passing a const pointer, and pybind11 will maintain the `std::vector` memory layout. * Thus the pointer will never be invalidated. */ -using CustomErrorFunction = std::function; +using CustomErrorFunction = std::function( + const CustomFactor &, const Values &, JacobianVector *)>; /** * @brief Custom factor that takes a std::function as the error @@ -77,7 +78,7 @@ public: * Calls the errorFunction closure, which is a std::function object * One can check if a derivative is needed in the errorFunction by checking the length of Jacobian array */ - Vector unwhitenedError(const Values &x, boost::optional &> H = boost::none) const override; + Vector unwhitenedError(const Values &x, boost::optional&> H = boost::none) const override; /** print */ void print(const std::string &s, diff --git a/python/gtsam/tests/test_custom_factor.py b/python/gtsam/tests/test_custom_factor.py index 4f0f33361..03e6917f0 100644 --- a/python/gtsam/tests/test_custom_factor.py +++ b/python/gtsam/tests/test_custom_factor.py @@ -8,13 +8,12 @@ See LICENSE for the license information CustomFactor unit tests. Author: Fan Jiang """ -from typing import List import unittest -from gtsam import Values, Pose2, CustomFactor - -import numpy as np +from typing import List import gtsam +import numpy as np +from gtsam import CustomFactor, JacobianFactor, Pose2, Values from gtsam.utils.test_case import GtsamTestCase @@ -24,17 +23,17 @@ class TestCustomFactor(GtsamTestCase): def error_func(this: CustomFactor, v: gtsam.Values, H: List[np.ndarray]): """Minimal error function stub""" - return np.array([1, 0, 0]) + return np.array([1, 0, 0]), H noise_model = gtsam.noiseModel.Unit.Create(3) - cf = CustomFactor(noise_model, gtsam.KeyVector([0]), error_func) + cf = CustomFactor(noise_model, [0], error_func) def test_new_keylist(self): """Test the creation of a new CustomFactor""" def error_func(this: CustomFactor, v: gtsam.Values, H: List[np.ndarray]): """Minimal error function stub""" - return np.array([1, 0, 0]) + return np.array([1, 0, 0]), H noise_model = gtsam.noiseModel.Unit.Create(3) cf = CustomFactor(noise_model, [0], error_func) @@ -47,7 +46,7 @@ class TestCustomFactor(GtsamTestCase): """Minimal error function with no Jacobian""" key0 = this.keys()[0] error = -v.atPose2(key0).localCoordinates(expected_pose) - return error + return error, H noise_model = gtsam.noiseModel.Unit.Create(3) cf = CustomFactor(noise_model, [0], error_func) @@ -81,10 +80,10 @@ class TestCustomFactor(GtsamTestCase): result = gT1.between(gT2) H[0] = -result.inverse().AdjointMap() H[1] = np.eye(3) - return error + return error, H noise_model = gtsam.noiseModel.Unit.Create(3) - cf = CustomFactor(noise_model, gtsam.KeyVector([0, 1]), error_func) + cf = CustomFactor(noise_model, [0, 1], error_func) v = Values() v.insert(0, gT1) v.insert(1, gT2) @@ -104,9 +103,9 @@ class TestCustomFactor(GtsamTestCase): gT1 = Pose2(1, 2, np.pi / 2) gT2 = Pose2(-1, 4, np.pi) - def error_func(this: CustomFactor, v: gtsam.Values, _: List[np.ndarray]): + def error_func(this: CustomFactor, v: gtsam.Values, H: List[np.ndarray]): """Minimal error function stub""" - return np.array([1, 0, 0]) + return np.array([1, 0, 0]), H noise_model = gtsam.noiseModel.Unit.Create(3) from gtsam.symbol_shorthand import X @@ -144,10 +143,10 @@ class TestCustomFactor(GtsamTestCase): result = gT1.between(gT2) H[0] = -result.inverse().AdjointMap() H[1] = np.eye(3) - return error + return error, H noise_model = gtsam.noiseModel.Unit.Create(3) - cf = CustomFactor(noise_model, gtsam.KeyVector([0, 1]), error_func) + cf = CustomFactor(noise_model, [0, 1], error_func) v = Values() v.insert(0, gT1) v.insert(1, gT2) @@ -182,7 +181,7 @@ class TestCustomFactor(GtsamTestCase): result = gT1.between(gT2) H[0] = -result.inverse().AdjointMap() H[1] = np.eye(3) - return error + return error, H noise_model = gtsam.noiseModel.Unit.Create(3) cf = CustomFactor(noise_model, [0, 1], error_func)