improve testing for deep copy

release/4.3a0
Varun Agrawal 2024-12-17 10:08:19 -05:00
parent 8f77bfa13b
commit 60b00ebdda
3 changed files with 32 additions and 24 deletions

View File

@ -5,11 +5,11 @@ All Rights Reserved
See LICENSE for the license information See LICENSE for the license information
KalmanFilter unit tests. Serialization and deep copy tests.
Author: Frank Dellaert & Duy Nguyen Ta (Python)
Author: Varun Agrawal
""" """
import unittest import unittest
from copy import deepcopy
import numpy as np import numpy as np
from gtsam.symbol_shorthand import B, V, X from gtsam.symbol_shorthand import B, V, X
@ -18,42 +18,36 @@ from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
class TestSerialization(GtsamTestCase): class TestDeepCopy(GtsamTestCase):
"""Tests for serialization of various GTSAM objects.""" """Tests for deep copy of various GTSAM objects."""
def test_PreintegratedImuMeasurements(self): def test_PreintegratedImuMeasurements(self):
""" """
Test the serialization of `PreintegratedImuMeasurements` by performing a deepcopy. Test the deep copy of `PreintegratedImuMeasurements` by performing a deepcopy.
""" """
params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81])) params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81]))
pim = gtsam.PreintegratedImuMeasurements(params) pim = gtsam.PreintegratedImuMeasurements(params)
# If serialization failed, then this will throw an error self.assertDeepCopyEquality(pim)
pim2 = deepcopy(pim)
self.assertEqual(pim, pim2)
def test_ImuFactor(self): def test_ImuFactor(self):
""" """
Test the serialization of `ImuFactor` by performing a deepcopy. Test the deep copy of `ImuFactor` by performing a deepcopy.
""" """
params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81])) params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81]))
pim = gtsam.PreintegratedImuMeasurements(params) pim = gtsam.PreintegratedImuMeasurements(params)
imu_odom = gtsam.ImuFactor(X(0), V(0), X(1), V(1), B(0), pim) imu_factor = gtsam.ImuFactor(X(0), V(0), X(1), V(1), B(0), pim)
# If serialization failed, then this will throw an error self.assertDeepCopyEquality(imu_factor)
imu_odom2 = deepcopy(imu_odom)
self.assertEqual(imu_odom, imu_odom2)
def test_PreintegratedCombinedMeasurements(self): def test_PreintegratedCombinedMeasurements(self):
""" """
Test the serialization of `PreintegratedCombinedMeasurements` by performing a deepcopy. Test the deep copy of `PreintegratedCombinedMeasurements` by performing a deepcopy.
""" """
params = gtsam.PreintegrationCombinedParams(np.asarray([0, 0, -9.81])) params = gtsam.PreintegrationCombinedParams(np.asarray([0, 0, -9.81]))
pim = gtsam.PreintegratedCombinedMeasurements(params) pim = gtsam.PreintegratedCombinedMeasurements(params)
# If serialization failed, then this will throw an error self.assertDeepCopyEquality(pim)
pim2 = deepcopy(pim)
self.assertEqual(pim, pim2)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -9,10 +9,12 @@ Unit tests to check pickling.
Author: Ayush Baid Author: Ayush Baid
""" """
from gtsam import Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3, Pose3, Rot3, SfmTrack, Unit3
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
from gtsam import (Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3,
Pose3, Rot3, SfmTrack, Unit3)
class TestPickle(GtsamTestCase): class TestPickle(GtsamTestCase):
"""Tests pickling on some of the classes.""" """Tests pickling on some of the classes."""

View File

@ -10,6 +10,7 @@ Author: Frank Dellaert
""" """
import pickle import pickle
import unittest import unittest
from copy import deepcopy
class GtsamTestCase(unittest.TestCase): class GtsamTestCase(unittest.TestCase):
@ -28,8 +29,8 @@ class GtsamTestCase(unittest.TestCase):
else: else:
equal = actual.equals(expected, tol) equal = actual.equals(expected, tol)
if not equal: if not equal:
raise self.failureException( raise self.failureException("Values are not equal:\n{}!={}".format(
"Values are not equal:\n{}!={}".format(actual, expected)) actual, expected))
def assertEqualityOnPickleRoundtrip(self, obj: object, tol=1e-9) -> None: def assertEqualityOnPickleRoundtrip(self, obj: object, tol=1e-9) -> None:
""" Performs a round-trip using pickle and asserts equality. """ Performs a round-trip using pickle and asserts equality.
@ -41,3 +42,14 @@ class GtsamTestCase(unittest.TestCase):
""" """
roundTripObj = pickle.loads(pickle.dumps(obj)) roundTripObj = pickle.loads(pickle.dumps(obj))
self.gtsamAssertEquals(roundTripObj, obj) self.gtsamAssertEquals(roundTripObj, obj)
def assertDeepCopyEquality(self, obj):
"""Perform assertion by checking if a
deep copied version of `obj` is equal to itself.
Args:
obj: The object to check is deep-copyable.
"""
# If deep copy failed, then this will throw an error
obj2 = deepcopy(obj)
self.gtsamAssertEquals(obj, obj2)