diff --git a/python/gtsam/tests/test_Serialization.py b/python/gtsam/tests/test_Serialization.py index 3e22e6603..e935f1f62 100644 --- a/python/gtsam/tests/test_Serialization.py +++ b/python/gtsam/tests/test_Serialization.py @@ -5,11 +5,11 @@ All Rights Reserved See LICENSE for the license information -KalmanFilter unit tests. -Author: Frank Dellaert & Duy Nguyen Ta (Python) +Serialization and deep copy tests. + +Author: Varun Agrawal """ import unittest -from copy import deepcopy import numpy as np from gtsam.symbol_shorthand import B, V, X @@ -18,42 +18,36 @@ from gtsam.utils.test_case import GtsamTestCase import gtsam -class TestSerialization(GtsamTestCase): - """Tests for serialization of various GTSAM objects.""" +class TestDeepCopy(GtsamTestCase): + """Tests for deep copy of various GTSAM objects.""" def test_PreintegratedImuMeasurements(self): """ - Test the serialization of `PreintegratedImuMeasurements` by performing a deepcopy. + Test the deep copy of `PreintegratedImuMeasurements` by performing a deepcopy. """ params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81])) pim = gtsam.PreintegratedImuMeasurements(params) - # If serialization failed, then this will throw an error - pim2 = deepcopy(pim) - self.assertEqual(pim, pim2) + self.assertDeepCopyEquality(pim) def test_ImuFactor(self): """ - Test the serialization of `ImuFactor` by performing a deepcopy. + Test the deep copy of `ImuFactor` by performing a deepcopy. """ params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81])) pim = gtsam.PreintegratedImuMeasurements(params) - imu_odom = gtsam.ImuFactor(X(0), V(0), X(1), V(1), B(0), pim) + imu_factor = gtsam.ImuFactor(X(0), V(0), X(1), V(1), B(0), pim) - # If serialization failed, then this will throw an error - imu_odom2 = deepcopy(imu_odom) - self.assertEqual(imu_odom, imu_odom2) + self.assertDeepCopyEquality(imu_factor) def test_PreintegratedCombinedMeasurements(self): """ - Test the serialization of `PreintegratedCombinedMeasurements` by performing a deepcopy. + Test the deep copy of `PreintegratedCombinedMeasurements` by performing a deepcopy. """ params = gtsam.PreintegrationCombinedParams(np.asarray([0, 0, -9.81])) pim = gtsam.PreintegratedCombinedMeasurements(params) - # If serialization failed, then this will throw an error - pim2 = deepcopy(pim) - self.assertEqual(pim, pim2) + self.assertDeepCopyEquality(pim) if __name__ == "__main__": diff --git a/python/gtsam/tests/test_pickle.py b/python/gtsam/tests/test_pickle.py index a6a5745bc..e51617b00 100644 --- a/python/gtsam/tests/test_pickle.py +++ b/python/gtsam/tests/test_pickle.py @@ -9,24 +9,26 @@ Unit tests to check pickling. Author: Ayush Baid """ -from gtsam import Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3, Pose3, Rot3, SfmTrack, Unit3 - from gtsam.utils.test_case import GtsamTestCase +from gtsam import (Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3, + Pose3, Rot3, SfmTrack, Unit3) + + class TestPickle(GtsamTestCase): """Tests pickling on some of the classes.""" def test_cal3Bundler_roundtrip(self): obj = Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70) self.assertEqualityOnPickleRoundtrip(obj) - + def test_pinholeCameraCal3Bundler_roundtrip(self): obj = PinholeCameraCal3Bundler( Pose3(Rot3.RzRyRx(0, 0.1, -0.05), Point3(1, 1, 0)), Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70), ) self.assertEqualityOnPickleRoundtrip(obj) - + def test_rot3_roundtrip(self): obj = Rot3.RzRyRx(0, 0.05, 0.1) self.assertEqualityOnPickleRoundtrip(obj) diff --git a/python/gtsam/utils/test_case.py b/python/gtsam/utils/test_case.py index 50af004f4..74eaff1db 100644 --- a/python/gtsam/utils/test_case.py +++ b/python/gtsam/utils/test_case.py @@ -10,6 +10,7 @@ Author: Frank Dellaert """ import pickle import unittest +from copy import deepcopy class GtsamTestCase(unittest.TestCase): @@ -28,8 +29,8 @@ class GtsamTestCase(unittest.TestCase): else: equal = actual.equals(expected, tol) if not equal: - raise self.failureException( - "Values are not equal:\n{}!={}".format(actual, expected)) + raise self.failureException("Values are not equal:\n{}!={}".format( + actual, expected)) def assertEqualityOnPickleRoundtrip(self, obj: object, tol=1e-9) -> None: """ Performs a round-trip using pickle and asserts equality. @@ -41,3 +42,14 @@ class GtsamTestCase(unittest.TestCase): """ roundTripObj = pickle.loads(pickle.dumps(obj)) self.gtsamAssertEquals(roundTripObj, obj) + + def assertDeepCopyEquality(self, obj): + """Perform assertion by checking if a + deep copied version of `obj` is equal to itself. + + Args: + obj: The object to check is deep-copyable. + """ + # If deep copy failed, then this will throw an error + obj2 = deepcopy(obj) + self.gtsamAssertEquals(obj, obj2)