improve testing for deep copy
parent
8f77bfa13b
commit
60b00ebdda
|
@ -5,11 +5,11 @@ All Rights Reserved
|
||||||
|
|
||||||
See LICENSE for the license information
|
See LICENSE for the license information
|
||||||
|
|
||||||
KalmanFilter unit tests.
|
Serialization and deep copy tests.
|
||||||
Author: Frank Dellaert & Duy Nguyen Ta (Python)
|
|
||||||
|
Author: Varun Agrawal
|
||||||
"""
|
"""
|
||||||
import unittest
|
import unittest
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gtsam.symbol_shorthand import B, V, X
|
from gtsam.symbol_shorthand import B, V, X
|
||||||
|
@ -18,42 +18,36 @@ from gtsam.utils.test_case import GtsamTestCase
|
||||||
import gtsam
|
import gtsam
|
||||||
|
|
||||||
|
|
||||||
class TestSerialization(GtsamTestCase):
|
class TestDeepCopy(GtsamTestCase):
|
||||||
"""Tests for serialization of various GTSAM objects."""
|
"""Tests for deep copy of various GTSAM objects."""
|
||||||
|
|
||||||
def test_PreintegratedImuMeasurements(self):
|
def test_PreintegratedImuMeasurements(self):
|
||||||
"""
|
"""
|
||||||
Test the serialization of `PreintegratedImuMeasurements` by performing a deepcopy.
|
Test the deep copy of `PreintegratedImuMeasurements` by performing a deepcopy.
|
||||||
"""
|
"""
|
||||||
params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81]))
|
params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81]))
|
||||||
pim = gtsam.PreintegratedImuMeasurements(params)
|
pim = gtsam.PreintegratedImuMeasurements(params)
|
||||||
|
|
||||||
# If serialization failed, then this will throw an error
|
self.assertDeepCopyEquality(pim)
|
||||||
pim2 = deepcopy(pim)
|
|
||||||
self.assertEqual(pim, pim2)
|
|
||||||
|
|
||||||
def test_ImuFactor(self):
|
def test_ImuFactor(self):
|
||||||
"""
|
"""
|
||||||
Test the serialization of `ImuFactor` by performing a deepcopy.
|
Test the deep copy of `ImuFactor` by performing a deepcopy.
|
||||||
"""
|
"""
|
||||||
params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81]))
|
params = gtsam.PreintegrationParams(np.asarray([0, 0, -9.81]))
|
||||||
pim = gtsam.PreintegratedImuMeasurements(params)
|
pim = gtsam.PreintegratedImuMeasurements(params)
|
||||||
imu_odom = gtsam.ImuFactor(X(0), V(0), X(1), V(1), B(0), pim)
|
imu_factor = gtsam.ImuFactor(X(0), V(0), X(1), V(1), B(0), pim)
|
||||||
|
|
||||||
# If serialization failed, then this will throw an error
|
self.assertDeepCopyEquality(imu_factor)
|
||||||
imu_odom2 = deepcopy(imu_odom)
|
|
||||||
self.assertEqual(imu_odom, imu_odom2)
|
|
||||||
|
|
||||||
def test_PreintegratedCombinedMeasurements(self):
|
def test_PreintegratedCombinedMeasurements(self):
|
||||||
"""
|
"""
|
||||||
Test the serialization of `PreintegratedCombinedMeasurements` by performing a deepcopy.
|
Test the deep copy of `PreintegratedCombinedMeasurements` by performing a deepcopy.
|
||||||
"""
|
"""
|
||||||
params = gtsam.PreintegrationCombinedParams(np.asarray([0, 0, -9.81]))
|
params = gtsam.PreintegrationCombinedParams(np.asarray([0, 0, -9.81]))
|
||||||
pim = gtsam.PreintegratedCombinedMeasurements(params)
|
pim = gtsam.PreintegratedCombinedMeasurements(params)
|
||||||
|
|
||||||
# If serialization failed, then this will throw an error
|
self.assertDeepCopyEquality(pim)
|
||||||
pim2 = deepcopy(pim)
|
|
||||||
self.assertEqual(pim, pim2)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -9,24 +9,26 @@ Unit tests to check pickling.
|
||||||
|
|
||||||
Author: Ayush Baid
|
Author: Ayush Baid
|
||||||
"""
|
"""
|
||||||
from gtsam import Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3, Pose3, Rot3, SfmTrack, Unit3
|
|
||||||
|
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
from gtsam import (Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3,
|
||||||
|
Pose3, Rot3, SfmTrack, Unit3)
|
||||||
|
|
||||||
|
|
||||||
class TestPickle(GtsamTestCase):
|
class TestPickle(GtsamTestCase):
|
||||||
"""Tests pickling on some of the classes."""
|
"""Tests pickling on some of the classes."""
|
||||||
|
|
||||||
def test_cal3Bundler_roundtrip(self):
|
def test_cal3Bundler_roundtrip(self):
|
||||||
obj = Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70)
|
obj = Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70)
|
||||||
self.assertEqualityOnPickleRoundtrip(obj)
|
self.assertEqualityOnPickleRoundtrip(obj)
|
||||||
|
|
||||||
def test_pinholeCameraCal3Bundler_roundtrip(self):
|
def test_pinholeCameraCal3Bundler_roundtrip(self):
|
||||||
obj = PinholeCameraCal3Bundler(
|
obj = PinholeCameraCal3Bundler(
|
||||||
Pose3(Rot3.RzRyRx(0, 0.1, -0.05), Point3(1, 1, 0)),
|
Pose3(Rot3.RzRyRx(0, 0.1, -0.05), Point3(1, 1, 0)),
|
||||||
Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70),
|
Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70),
|
||||||
)
|
)
|
||||||
self.assertEqualityOnPickleRoundtrip(obj)
|
self.assertEqualityOnPickleRoundtrip(obj)
|
||||||
|
|
||||||
def test_rot3_roundtrip(self):
|
def test_rot3_roundtrip(self):
|
||||||
obj = Rot3.RzRyRx(0, 0.05, 0.1)
|
obj = Rot3.RzRyRx(0, 0.05, 0.1)
|
||||||
self.assertEqualityOnPickleRoundtrip(obj)
|
self.assertEqualityOnPickleRoundtrip(obj)
|
||||||
|
|
|
@ -10,6 +10,7 @@ Author: Frank Dellaert
|
||||||
"""
|
"""
|
||||||
import pickle
|
import pickle
|
||||||
import unittest
|
import unittest
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
class GtsamTestCase(unittest.TestCase):
|
class GtsamTestCase(unittest.TestCase):
|
||||||
|
@ -28,8 +29,8 @@ class GtsamTestCase(unittest.TestCase):
|
||||||
else:
|
else:
|
||||||
equal = actual.equals(expected, tol)
|
equal = actual.equals(expected, tol)
|
||||||
if not equal:
|
if not equal:
|
||||||
raise self.failureException(
|
raise self.failureException("Values are not equal:\n{}!={}".format(
|
||||||
"Values are not equal:\n{}!={}".format(actual, expected))
|
actual, expected))
|
||||||
|
|
||||||
def assertEqualityOnPickleRoundtrip(self, obj: object, tol=1e-9) -> None:
|
def assertEqualityOnPickleRoundtrip(self, obj: object, tol=1e-9) -> None:
|
||||||
""" Performs a round-trip using pickle and asserts equality.
|
""" Performs a round-trip using pickle and asserts equality.
|
||||||
|
@ -41,3 +42,14 @@ class GtsamTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
roundTripObj = pickle.loads(pickle.dumps(obj))
|
roundTripObj = pickle.loads(pickle.dumps(obj))
|
||||||
self.gtsamAssertEquals(roundTripObj, obj)
|
self.gtsamAssertEquals(roundTripObj, obj)
|
||||||
|
|
||||||
|
def assertDeepCopyEquality(self, obj):
|
||||||
|
"""Perform assertion by checking if a
|
||||||
|
deep copied version of `obj` is equal to itself.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The object to check is deep-copyable.
|
||||||
|
"""
|
||||||
|
# If deep copy failed, then this will throw an error
|
||||||
|
obj2 = deepcopy(obj)
|
||||||
|
self.gtsamAssertEquals(obj, obj2)
|
||||||
|
|
Loading…
Reference in New Issue