diff --git a/cmake/HandleUninstall.cmake b/cmake/HandleUninstall.cmake index 1859b0273..dccb1905e 100644 --- a/cmake/HandleUninstall.cmake +++ b/cmake/HandleUninstall.cmake @@ -6,5 +6,11 @@ configure_file( "${CMAKE_CURRENT_BINARY_DIR}/cmake_uninstall.cmake" IMMEDIATE @ONLY) -add_custom_target(uninstall - "${CMAKE_COMMAND}" -P "${CMAKE_CURRENT_BINARY_DIR}/cmake_uninstall.cmake") +if (NOT TARGET uninstall) # avoid duplicating this target + add_custom_target(uninstall + "${CMAKE_COMMAND}" -P "${CMAKE_CURRENT_BINARY_DIR}/cmake_uninstall.cmake") +else() + add_custom_target(uninstall_gtsam + "${CMAKE_COMMAND}" -P "${CMAKE_CURRENT_BINARY_DIR}/cmake_uninstall.cmake") + add_dependencies(uninstall uninstall_gtsam) +endif() diff --git a/gtsam/gtsam.i b/gtsam/gtsam.i index 92a7b1834..2ca57504c 100644 --- a/gtsam/gtsam.i +++ b/gtsam/gtsam.i @@ -105,6 +105,9 @@ * virtual class MyFactor : gtsam::NoiseModelFactor {...}; * - *DO NOT* re-define overriden function already declared in the external (forward-declared) base class * - This will cause an ambiguity problem in Pybind header file + * Pickle support in Python: + * - Add "void pickle()" to a class to enable pickling via gtwrap. In the current implementation, "void serialize()" + * and a public constructor with no-arguments in needed for successful build. */ /** @@ -144,6 +147,9 @@ class KeyList { void remove(size_t key); void serialize() const; + + // enable pickling in python + void pickle() const; }; // Actually a FastSet @@ -169,6 +175,9 @@ class KeySet { bool count(size_t key) const; // returns true if value exists void serialize() const; + + // enable pickling in python + void pickle() const; }; // Actually a vector @@ -190,6 +199,9 @@ class KeyVector { void push_back(size_t key) const; void serialize() const; + + // enable pickling in python + void pickle() const; }; // Actually a FastMap @@ -361,6 +373,9 @@ class Point2 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; // std::vector @@ -422,6 +437,9 @@ class StereoPoint2 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -446,6 +464,9 @@ class Point3 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -501,6 +522,9 @@ class Rot2 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -663,6 +687,9 @@ class Rot3 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -718,6 +745,9 @@ class Pose2 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -774,6 +804,9 @@ class Pose3 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -817,6 +850,15 @@ class Unit3 { size_t dim() const; gtsam::Unit3 retract(Vector v) const; Vector localCoordinates(const gtsam::Unit3& s) const; + + // enabling serialization functionality + void serialize() const; + + // enable pickling in python + void pickle() const; + + // enabling function to compare objects + bool equals(const gtsam::Unit3& expected, double tol) const; }; #include @@ -876,6 +918,9 @@ class Cal3_S2 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -904,6 +949,9 @@ virtual class Cal3DS2_Base { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -925,6 +973,9 @@ virtual class Cal3DS2 : gtsam::Cal3DS2_Base { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -951,6 +1002,9 @@ virtual class Cal3Unified : gtsam::Cal3DS2_Base { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1008,6 +1062,9 @@ class Cal3Bundler { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1038,6 +1095,9 @@ class CalibratedCamera { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1076,6 +1136,9 @@ class PinholeCamera { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; @@ -1146,6 +1209,9 @@ class StereoCamera { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1600,6 +1666,9 @@ class VectorValues { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1661,6 +1730,9 @@ virtual class JacobianFactor : gtsam::GaussianFactor { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1692,6 +1764,9 @@ virtual class HessianFactor : gtsam::GaussianFactor { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1771,6 +1846,9 @@ class GaussianFactorGraph { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -2076,6 +2154,9 @@ class Ordering { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -2114,6 +2195,10 @@ class NonlinearFactorGraph { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; + void saveGraph(const string& s) const; }; @@ -2171,6 +2256,9 @@ class Values { // enabling serialization functionality void serialize() const; + // enable pickling in python + void pickle() const; + // New in 4.0, we have to specialize every insert/update/at to generate wrappers // Instead of the old: // void insert(size_t j, const gtsam::Value& value); @@ -2570,6 +2658,9 @@ virtual class PriorFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; @@ -2581,6 +2672,9 @@ virtual class BetweenFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -2817,6 +2911,9 @@ class SfmTrack { // enabling serialization functionality void serialize() const; + // enable pickling in python + void pickle() const; + // enabling function to compare objects bool equals(const gtsam::SfmTrack& expected, double tol) const; }; @@ -2833,6 +2930,9 @@ class SfmData { // enabling serialization functionality void serialize() const; + // enable pickling in python + void pickle() const; + // enabling function to compare objects bool equals(const gtsam::SfmData& expected, double tol) const; }; diff --git a/gtsam/sfm/TranslationRecovery.cpp b/gtsam/sfm/TranslationRecovery.cpp index d4100b00a..f38c14ba7 100644 --- a/gtsam/sfm/TranslationRecovery.cpp +++ b/gtsam/sfm/TranslationRecovery.cpp @@ -81,6 +81,7 @@ void TranslationRecovery::addPrior( const double scale, NonlinearFactorGraph *graph, const SharedNoiseModel &priorNoiseModel) const { auto edge = relativeTranslations_.begin(); + if (edge == relativeTranslations_.end()) return; graph->emplace_shared >(edge->key1(), Point3(0, 0, 0), priorNoiseModel); graph->emplace_shared >( @@ -102,6 +103,15 @@ Values TranslationRecovery::initalizeRandomly() const { insert(edge.key1()); insert(edge.key2()); } + + // If there are no valid edges, but zero-distance edges exist, initialize one + // of the nodes in a connected component of zero-distance edges. + if (initial.empty() && !sameTranslationNodes_.empty()) { + for (const auto &optimizedAndDuplicateKeys : sameTranslationNodes_) { + Key optimizedKey = optimizedAndDuplicateKeys.first; + initial.insert(optimizedKey, Point3(0, 0, 0)); + } + } return initial; } diff --git a/python/gtsam/tests/test_pickle.py b/python/gtsam/tests/test_pickle.py new file mode 100644 index 000000000..0acbf6765 --- /dev/null +++ b/python/gtsam/tests/test_pickle.py @@ -0,0 +1,46 @@ +""" +GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests to check pickling. + +Author: Ayush Baid +""" +from gtsam import Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3, Pose3, Rot3, SfmTrack, Unit3 + +from gtsam.utils.test_case import GtsamTestCase + +class TestPickle(GtsamTestCase): + """Tests pickling on some of the classes.""" + + def test_cal3Bundler_roundtrip(self): + obj = Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70) + self.assertEqualityOnPickleRoundtrip(obj) + + def test_pinholeCameraCal3Bundler_roundtrip(self): + obj = PinholeCameraCal3Bundler( + Pose3(Rot3.RzRyRx(0, 0.1, -0.05), Point3(1, 1, 0)), + Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70), + ) + self.assertEqualityOnPickleRoundtrip(obj) + + def test_rot3_roundtrip(self): + obj = Rot3.RzRyRx(0, 0.05, 0.1) + self.assertEqualityOnPickleRoundtrip(obj) + + def test_pose3_roundtrip(self): + obj = Pose3(Rot3.Ypr(0.0, 1.0, 0.0), Point3(1, 1, 0)) + self.assertEqualityOnPickleRoundtrip(obj) + + def test_sfmTrack_roundtrip(self): + obj = SfmTrack(Point3(1, 1, 0)) + obj.add_measurement(0, Point2(-1, 5)) + obj.add_measurement(1, Point2(6, 2)) + self.assertEqualityOnPickleRoundtrip(obj) + + def test_unit3_roundtrip(self): + obj = Unit3(Point3(1, 1, 0)) + self.assertEqualityOnPickleRoundtrip(obj) diff --git a/python/gtsam/utils/test_case.py b/python/gtsam/utils/test_case.py index 3effd7f65..50af004f4 100644 --- a/python/gtsam/utils/test_case.py +++ b/python/gtsam/utils/test_case.py @@ -8,6 +8,7 @@ See LICENSE for the license information TestCase class with GTSAM assert utils. Author: Frank Dellaert """ +import pickle import unittest @@ -29,3 +30,14 @@ class GtsamTestCase(unittest.TestCase): if not equal: raise self.failureException( "Values are not equal:\n{}!={}".format(actual, expected)) + + def assertEqualityOnPickleRoundtrip(self, obj: object, tol=1e-9) -> None: + """ Performs a round-trip using pickle and asserts equality. + + Usage: + self.assertEqualityOnPickleRoundtrip(obj) + Keyword Arguments: + tol {float} -- tolerance passed to 'equals', default 1e-9 + """ + roundTripObj = pickle.loads(pickle.dumps(obj)) + self.gtsamAssertEquals(roundTripObj, obj) diff --git a/tests/testTranslationRecovery.cpp b/tests/testTranslationRecovery.cpp index 7260fd5af..2915a375e 100644 --- a/tests/testTranslationRecovery.cpp +++ b/tests/testTranslationRecovery.cpp @@ -17,7 +17,6 @@ */ #include - #include #include @@ -185,7 +184,7 @@ TEST(TranslationRecovery, ThreePosesIncludingZeroTranslation) { TranslationRecovery algorithm(relativeTranslations); const auto graph = algorithm.buildGraph(); // There is only 1 non-zero translation edge. - EXPECT_LONGS_EQUAL(1, graph.size()); + EXPECT_LONGS_EQUAL(1, graph.size()); // Run translation recovery const auto result = algorithm.run(/*scale=*/3.0); @@ -238,6 +237,35 @@ TEST(TranslationRecovery, FourPosesIncludingZeroTranslation) { EXPECT(assert_equal(Point3(2, -2, 0), result.at(3))); } +TEST(TranslationRecovery, ThreePosesWithZeroTranslation) { + Values poses; + poses.insert(0, Pose3(Rot3::RzRyRx(-M_PI / 6, 0, 0), Point3(0, 0, 0))); + poses.insert(1, Pose3(Rot3(), Point3(0, 0, 0))); + poses.insert(2, Pose3(Rot3::RzRyRx(M_PI / 6, 0, 0), Point3(0, 0, 0))); + + auto relativeTranslations = TranslationRecovery::SimulateMeasurements( + poses, {{0, 1}, {1, 2}, {2, 0}}); + + // Check simulated measurements. + for (auto& unitTranslation : relativeTranslations) { + EXPECT(assert_equal(GetDirectionFromPoses(poses, unitTranslation), + unitTranslation.measured())); + } + + TranslationRecovery algorithm(relativeTranslations); + const auto graph = algorithm.buildGraph(); + // Graph size will be zero as there no 'non-zero distance' edges. + EXPECT_LONGS_EQUAL(0, graph.size()); + + // Run translation recovery + const auto result = algorithm.run(/*scale=*/4.0); + + // Check result + EXPECT(assert_equal(Point3(0, 0, 0), result.at(0))); + EXPECT(assert_equal(Point3(0, 0, 0), result.at(1))); + EXPECT(assert_equal(Point3(0, 0, 0), result.at(2))); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/wrap/gtwrap/matlab_wrapper.py b/wrap/gtwrap/matlab_wrapper.py index 669bf474f..fe4ee7e19 100755 --- a/wrap/gtwrap/matlab_wrapper.py +++ b/wrap/gtwrap/matlab_wrapper.py @@ -49,6 +49,8 @@ class MatlabWrapper(object): } """Methods that should not be wrapped directly""" whitelist = ['serializable', 'serialize'] + """Methods that should be ignored""" + ignore_methods = ['pickle'] """Datatypes that do not need to be checked in methods""" not_check_type = [] """Data types that are primitive types""" @@ -563,6 +565,8 @@ class MatlabWrapper(object): for method in methods: if method.name in self.whitelist: continue + if method.name in self.ignore_methods: + continue comment += '%{name}({args})'.format(name=method.name, args=self._wrap_args(method.args)) @@ -612,6 +616,9 @@ class MatlabWrapper(object): methods = self._group_methods(methods) for method in methods: + if method in self.ignore_methods: + continue + if globals: self._debug("[wrap_methods] wrapping: {}..{}={}".format(method[0].parent.name, method[0].name, type(method[0].parent.name))) @@ -861,6 +868,8 @@ class MatlabWrapper(object): method_name = method[0].name if method_name in self.whitelist and method_name != 'serialize': continue + if method_name in self.ignore_methods: + continue if method_name == 'serialize': serialize[0] = True @@ -932,6 +941,9 @@ class MatlabWrapper(object): format_name = list(static_method[0].name) format_name[0] = format_name[0].upper() + if static_method[0].name in self.ignore_methods: + continue + method_text += textwrap.indent(textwrap.dedent('''\ function varargout = {name}(varargin) '''.format(name=''.join(format_name))), diff --git a/wrap/gtwrap/pybind_wrapper.py b/wrap/gtwrap/pybind_wrapper.py index c0e88e37a..a045afcbd 100755 --- a/wrap/gtwrap/pybind_wrapper.py +++ b/wrap/gtwrap/pybind_wrapper.py @@ -76,6 +76,21 @@ class PybindWrapper(object): gtsam::deserialize(serialized, *self); }}, py::arg("serialized")) '''.format(class_inst=cpp_class + '*')) + if cpp_method == "pickle": + if not cpp_class in self._serializing_classes: + raise ValueError("Cannot pickle a class which is not serializable") + return textwrap.dedent(''' + .def(py::pickle( + [](const {cpp_class} &a){{ // __getstate__ + /* Returns a string that encodes the state of the object */ + return py::make_tuple(gtsam::serialize(a)); + }}, + [](py::tuple t){{ // __setstate__ + {cpp_class} obj; + gtsam::deserialize(t[0].cast(), obj); + return obj; + }})) + '''.format(cpp_class=cpp_class)) is_method = isinstance(method, instantiator.InstantiatedMethod) is_static = isinstance(method, parser.StaticMethod) @@ -318,3 +333,4 @@ class PybindWrapper(object): wrapped_namespace=wrapped_namespace, boost_class_export=boost_class_export, ) + diff --git a/wrap/tests/expected-python/geometry_pybind.cpp b/wrap/tests/expected-python/geometry_pybind.cpp index 3eee55bf4..be6482d89 100644 --- a/wrap/tests/expected-python/geometry_pybind.cpp +++ b/wrap/tests/expected-python/geometry_pybind.cpp @@ -47,6 +47,17 @@ PYBIND11_MODULE(geometry_py, m_) { [](gtsam::Point2* self, string serialized){ gtsam::deserialize(serialized, *self); }, py::arg("serialized")) + +.def(py::pickle( + [](const gtsam::Point2 &a){ // __getstate__ + /* Returns a string that encodes the state of the object */ + return py::make_tuple(gtsam::serialize(a)); + }, + [](py::tuple t){ // __setstate__ + gtsam::Point2 obj; + gtsam::deserialize(t[0].cast(), obj); + return obj; + })) ; py::class_>(m_gtsam, "Point3") @@ -62,6 +73,17 @@ PYBIND11_MODULE(geometry_py, m_) { gtsam::deserialize(serialized, *self); }, py::arg("serialized")) +.def(py::pickle( + [](const gtsam::Point3 &a){ // __getstate__ + /* Returns a string that encodes the state of the object */ + return py::make_tuple(gtsam::serialize(a)); + }, + [](py::tuple t){ // __setstate__ + gtsam::Point3 obj; + gtsam::deserialize(t[0].cast(), obj); + return obj; + })) + .def_static("staticFunction",[](){return gtsam::Point3::staticFunction();}) .def_static("StaticFunctionRet",[]( double z){return gtsam::Point3::StaticFunctionRet(z);}, py::arg("z")); diff --git a/wrap/tests/geometry.h b/wrap/tests/geometry.h index 40d878c9f..ec5d3b277 100644 --- a/wrap/tests/geometry.h +++ b/wrap/tests/geometry.h @@ -22,6 +22,9 @@ class Point2 { VectorNotEigen vectorConfusion(); void serializable() const; // Sets flag and creates export, but does not make serialization functions + + // enable pickling in python + void pickle() const; }; #include @@ -35,6 +38,9 @@ class Point3 { // enabling serialization functionality void serialize() const; // Just triggers a flag internally and removes actual function + + // enable pickling in python + void pickle() const; }; }