improve testing for deep copy
parent
8f77bfa13b
commit
60b00ebdda
|
@ -5,11 +5,11 @@ All Rights Reserved
|
|||
|
||||
See LICENSE for the license information
|
||||
|
||||
KalmanFilter unit tests.
|
||||
Author: Frank Dellaert & Duy Nguyen Ta (Python)
|
||||
Serialization and deep copy tests.
|
||||
|
||||
Author: Varun Agrawal
|
||||
"""
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from gtsam.symbol_shorthand import B, V, X
|
||||
|
@ -18,42 +18,36 @@ from gtsam.utils.test_case import GtsamTestCase
|
|||
import gtsam
|
||||
|
||||
|
||||
class TestSerialization(GtsamTestCase):
|
||||
"""Tests for serialization of various GTSAM objects."""
|
||||
class TestDeepCopy(GtsamTestCase):
|
||||
"""Tests for deep copy of various GTSAM objects."""
|
||||
|
||||
def test_PreintegratedImuMeasurements(self):
|
||||
"""
|
||||
Test the serialization of `PreintegratedImuMeasurements` by performing a deepcopy.
|
||||
Test the deep copy of `PreintegratedImuMeasurements` by performing a deepcopy.
|
||||
"""
|
||||
params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81]))
|
||||
pim = gtsam.PreintegratedImuMeasurements(params)
|
||||
|
||||
# If serialization failed, then this will throw an error
|
||||
pim2 = deepcopy(pim)
|
||||
self.assertEqual(pim, pim2)
|
||||
self.assertDeepCopyEquality(pim)
|
||||
|
||||
def test_ImuFactor(self):
|
||||
"""
|
||||
Test the serialization of `ImuFactor` by performing a deepcopy.
|
||||
Test the deep copy of `ImuFactor` by performing a deepcopy.
|
||||
"""
|
||||
params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81]))
|
||||
pim = gtsam.PreintegratedImuMeasurements(params)
|
||||
imu_odom = gtsam.ImuFactor(X(0), V(0), X(1), V(1), B(0), pim)
|
||||
imu_factor = gtsam.ImuFactor(X(0), V(0), X(1), V(1), B(0), pim)
|
||||
|
||||
# If serialization failed, then this will throw an error
|
||||
imu_odom2 = deepcopy(imu_odom)
|
||||
self.assertEqual(imu_odom, imu_odom2)
|
||||
self.assertDeepCopyEquality(imu_factor)
|
||||
|
||||
def test_PreintegratedCombinedMeasurements(self):
|
||||
"""
|
||||
Test the serialization of `PreintegratedCombinedMeasurements` by performing a deepcopy.
|
||||
Test the deep copy of `PreintegratedCombinedMeasurements` by performing a deepcopy.
|
||||
"""
|
||||
params = gtsam.PreintegrationCombinedParams(np.asarray([0, 0, -9.81]))
|
||||
pim = gtsam.PreintegratedCombinedMeasurements(params)
|
||||
|
||||
# If serialization failed, then this will throw an error
|
||||
pim2 = deepcopy(pim)
|
||||
self.assertEqual(pim, pim2)
|
||||
self.assertDeepCopyEquality(pim)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -9,24 +9,26 @@ Unit tests to check pickling.
|
|||
|
||||
Author: Ayush Baid
|
||||
"""
|
||||
from gtsam import Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3, Pose3, Rot3, SfmTrack, Unit3
|
||||
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
from gtsam import (Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3,
|
||||
Pose3, Rot3, SfmTrack, Unit3)
|
||||
|
||||
|
||||
class TestPickle(GtsamTestCase):
|
||||
"""Tests pickling on some of the classes."""
|
||||
|
||||
def test_cal3Bundler_roundtrip(self):
|
||||
obj = Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70)
|
||||
self.assertEqualityOnPickleRoundtrip(obj)
|
||||
|
||||
|
||||
def test_pinholeCameraCal3Bundler_roundtrip(self):
|
||||
obj = PinholeCameraCal3Bundler(
|
||||
Pose3(Rot3.RzRyRx(0, 0.1, -0.05), Point3(1, 1, 0)),
|
||||
Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70),
|
||||
)
|
||||
self.assertEqualityOnPickleRoundtrip(obj)
|
||||
|
||||
|
||||
def test_rot3_roundtrip(self):
|
||||
obj = Rot3.RzRyRx(0, 0.05, 0.1)
|
||||
self.assertEqualityOnPickleRoundtrip(obj)
|
||||
|
|
|
@ -10,6 +10,7 @@ Author: Frank Dellaert
|
|||
"""
|
||||
import pickle
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class GtsamTestCase(unittest.TestCase):
|
||||
|
@ -28,8 +29,8 @@ class GtsamTestCase(unittest.TestCase):
|
|||
else:
|
||||
equal = actual.equals(expected, tol)
|
||||
if not equal:
|
||||
raise self.failureException(
|
||||
"Values are not equal:\n{}!={}".format(actual, expected))
|
||||
raise self.failureException("Values are not equal:\n{}!={}".format(
|
||||
actual, expected))
|
||||
|
||||
def assertEqualityOnPickleRoundtrip(self, obj: object, tol=1e-9) -> None:
|
||||
""" Performs a round-trip using pickle and asserts equality.
|
||||
|
@ -41,3 +42,14 @@ class GtsamTestCase(unittest.TestCase):
|
|||
"""
|
||||
roundTripObj = pickle.loads(pickle.dumps(obj))
|
||||
self.gtsamAssertEquals(roundTripObj, obj)
|
||||
|
||||
def assertDeepCopyEquality(self, obj):
|
||||
"""Perform assertion by checking if a
|
||||
deep copied version of `obj` is equal to itself.
|
||||
|
||||
Args:
|
||||
obj: The object to check is deep-copyable.
|
||||
"""
|
||||
# If deep copy failed, then this will throw an error
|
||||
obj2 = deepcopy(obj)
|
||||
self.gtsamAssertEquals(obj, obj2)
|
||||
|
|
Loading…
Reference in New Issue