improve testing for deep copy

release/4.3a0
Varun Agrawal 2024-12-17 10:08:19 -05:00
parent 8f77bfa13b
commit 60b00ebdda
3 changed files with 32 additions and 24 deletions

View File

@ -5,11 +5,11 @@ All Rights Reserved
See LICENSE for the license information
KalmanFilter unit tests.
Author: Frank Dellaert & Duy Nguyen Ta (Python)
Serialization and deep copy tests.
Author: Varun Agrawal
"""
import unittest
from copy import deepcopy
import numpy as np
from gtsam.symbol_shorthand import B, V, X
@ -18,42 +18,36 @@ from gtsam.utils.test_case import GtsamTestCase
import gtsam
class TestSerialization(GtsamTestCase):
"""Tests for serialization of various GTSAM objects."""
class TestDeepCopy(GtsamTestCase):
"""Tests for deep copy of various GTSAM objects."""
def test_PreintegratedImuMeasurements(self):
"""
Test the serialization of `PreintegratedImuMeasurements` by performing a deepcopy.
Test the deep copy of `PreintegratedImuMeasurements` by performing a deepcopy.
"""
params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81]))
pim = gtsam.PreintegratedImuMeasurements(params)
# If serialization failed, then this will throw an error
pim2 = deepcopy(pim)
self.assertEqual(pim, pim2)
self.assertDeepCopyEquality(pim)
def test_ImuFactor(self):
"""
Test the serialization of `ImuFactor` by performing a deepcopy.
Test the deep copy of `ImuFactor` by performing a deepcopy.
"""
params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81]))
pim = gtsam.PreintegratedImuMeasurements(params)
imu_odom = gtsam.ImuFactor(X(0), V(0), X(1), V(1), B(0), pim)
imu_factor = gtsam.ImuFactor(X(0), V(0), X(1), V(1), B(0), pim)
# If serialization failed, then this will throw an error
imu_odom2 = deepcopy(imu_odom)
self.assertEqual(imu_odom, imu_odom2)
self.assertDeepCopyEquality(imu_factor)
def test_PreintegratedCombinedMeasurements(self):
"""
Test the serialization of `PreintegratedCombinedMeasurements` by performing a deepcopy.
Test the deep copy of `PreintegratedCombinedMeasurements` by performing a deepcopy.
"""
params = gtsam.PreintegrationCombinedParams(np.asarray([0, 0, -9.81]))
pim = gtsam.PreintegratedCombinedMeasurements(params)
# If serialization failed, then this will throw an error
pim2 = deepcopy(pim)
self.assertEqual(pim, pim2)
self.assertDeepCopyEquality(pim)
if __name__ == "__main__":

View File

@ -9,24 +9,26 @@ Unit tests to check pickling.
Author: Ayush Baid
"""
from gtsam import Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3, Pose3, Rot3, SfmTrack, Unit3
from gtsam.utils.test_case import GtsamTestCase
from gtsam import (Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3,
Pose3, Rot3, SfmTrack, Unit3)
class TestPickle(GtsamTestCase):
"""Tests pickling on some of the classes."""
def test_cal3Bundler_roundtrip(self):
obj = Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70)
self.assertEqualityOnPickleRoundtrip(obj)
def test_pinholeCameraCal3Bundler_roundtrip(self):
obj = PinholeCameraCal3Bundler(
Pose3(Rot3.RzRyRx(0, 0.1, -0.05), Point3(1, 1, 0)),
Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70),
)
self.assertEqualityOnPickleRoundtrip(obj)
def test_rot3_roundtrip(self):
obj = Rot3.RzRyRx(0, 0.05, 0.1)
self.assertEqualityOnPickleRoundtrip(obj)

View File

@ -10,6 +10,7 @@ Author: Frank Dellaert
"""
import pickle
import unittest
from copy import deepcopy
class GtsamTestCase(unittest.TestCase):
@ -28,8 +29,8 @@ class GtsamTestCase(unittest.TestCase):
else:
equal = actual.equals(expected, tol)
if not equal:
raise self.failureException(
"Values are not equal:\n{}!={}".format(actual, expected))
raise self.failureException("Values are not equal:\n{}!={}".format(
actual, expected))
def assertEqualityOnPickleRoundtrip(self, obj: object, tol=1e-9) -> None:
""" Performs a round-trip using pickle and asserts equality.
@ -41,3 +42,14 @@ class GtsamTestCase(unittest.TestCase):
"""
roundTripObj = pickle.loads(pickle.dumps(obj))
self.gtsamAssertEquals(roundTripObj, obj)
def assertDeepCopyEquality(self, obj):
"""Perform assertion by checking if a
deep copied version of `obj` is equal to itself.
Args:
obj: The object to check is deep-copyable.
"""
# If deep copy failed, then this will throw an error
obj2 = deepcopy(obj)
self.gtsamAssertEquals(obj, obj2)