Fix 4.3 python style

release/4.3a0
Frank Dellaert 2023-09-02 16:55:38 -07:00
parent 12f919dc55
commit 2f2d6546d1
1 changed files with 35 additions and 35 deletions

View File

@ -4,14 +4,15 @@ Authors: John Lambert
""" """
import unittest import unittest
from typing import Dict, List, Tuple from typing import Dict, Tuple
import numpy as np import numpy as np
from gtsam.gtsfm import Keypoints from gtsam.gtsfm import Keypoints
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
from gtsam import IndexPair, Point2, SfmTrack2d from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
SfmMeasurementVector, SfmTrack2d)
class TestDsfTrackGenerator(GtsamTestCase): class TestDsfTrackGenerator(GtsamTestCase):
@ -22,20 +23,21 @@ class TestDsfTrackGenerator(GtsamTestCase):
) -> None: ) -> None:
"""Tests DSF for non-transitive matches. """Tests DSF for non-transitive matches.
Test will result in no tracks since nontransitive tracks are naively discarded by DSF. Test will result in no tracks since nontransitive tracks are naively
discarded by DSF.
""" """
keypoints_list = get_dummy_keypoints_list() keypoints = get_dummy_keypoints_list()
nontransitive_matches_dict = get_nontransitive_matches() # contains one non-transitive track nontransitive_matches = get_nontransitive_matches()
# For each image pair (i1,i2), we provide a (K,2) matrix # For each image pair (i1,i2), we provide a (K,2) matrix
# of corresponding keypoint indices (k1,k2). # of corresponding keypoint indices (k1,k2).
matches_dict = {} matches = MatchIndicesMap()
for (i1,i2), corr_idxs in nontransitive_matches_dict.items(): for (i1, i2), correspondences in nontransitive_matches.items():
matches_dict[IndexPair(i1, i2)] = corr_idxs matches[IndexPair(i1, i2)] = correspondences
tracks = gtsam.gtsfm.tracksFromPairwiseMatches( tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
matches_dict, matches,
keypoints_list, keypoints,
verbose=True, verbose=True,
) )
self.assertEqual(len(tracks), 0, "Tracks not filtered correctly") self.assertEqual(len(tracks), 0, "Tracks not filtered correctly")
@ -47,20 +49,20 @@ class TestDsfTrackGenerator(GtsamTestCase):
kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]])) kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]])) kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))
keypoints_list = [] keypoints = KeypointsVector()
keypoints_list.append(kps_i0) keypoints.append(kps_i0)
keypoints_list.append(kps_i1) keypoints.append(kps_i1)
keypoints_list.append(kps_i2) keypoints.append(kps_i2)
# For each image pair (i1,i2), we provide a (K,2) matrix # For each image pair (i1,i2), we provide a (K,2) matrix
# of corresponding keypoint indices (k1,k2). # of corresponding image indices (k1,k2).
matches_dict = {} matches = MatchIndicesMap()
matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]]) matches[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]]) matches[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
tracks = gtsam.gtsfm.tracksFromPairwiseMatches( tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
matches_dict, matches,
keypoints_list, keypoints,
verbose=False, verbose=False,
) )
assert len(tracks) == 3 assert len(tracks) == 3
@ -110,17 +112,16 @@ class TestSfmTrack2d(GtsamTestCase):
def test_sfm_track_2d_constructor(self) -> None: def test_sfm_track_2d_constructor(self) -> None:
"""Test construction of 2D SfM track.""" """Test construction of 2D SfM track."""
measurements = [] measurements = SfmMeasurementVector()
measurements.append((0, Point2(10, 20))) measurements.append((0, Point2(10, 20)))
track = SfmTrack2d(measurements=measurements) track = SfmTrack2d(measurements=measurements)
track.measurement(0) track.measurement(0)
assert track.numberMeasurements() == 1 assert track.numberMeasurements() == 1
def get_dummy_keypoints_list() -> List[Keypoints]: def get_dummy_keypoints_list() -> KeypointsVector:
""" """ """Generate a list of dummy keypoints for testing."""
img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]]) img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]])
img1_kp_scale = np.array([6.0, 9.0, 8.5])
img2_kp_coords = np.array( img2_kp_coords = np.array(
[ [
[1, 1.], [1, 1.],
@ -156,33 +157,32 @@ def get_dummy_keypoints_list() -> List[Keypoints]:
[5, 5], [5, 5],
] ]
) )
keypoints_list = [ keypoints = KeypointsVector()
Keypoints(coordinates=img1_kp_coords), keypoints.append(Keypoints(coordinates=img1_kp_coords))
Keypoints(coordinates=img2_kp_coords), keypoints.append(Keypoints(coordinates=img2_kp_coords))
Keypoints(coordinates=img3_kp_coords), keypoints.append(Keypoints(coordinates=img3_kp_coords))
Keypoints(coordinates=img4_kp_coords), keypoints.append(Keypoints(coordinates=img4_kp_coords))
] return keypoints
return keypoints_list
def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]: def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]:
"""Set up correspondences for each (i1,i2) pair that violates transitivity. """Set up correspondences for each (i1,i2) pair that violates transitivity.
(i=0, k=0) (i=0, k=1) (i=0, k=0) (i=0, k=1)
| \\ | | \\ |
| \\ | | \\ |
(i=1, k=2)--(i=2,k=3)--(i=3, k=4) (i=1, k=2)--(i=2,k=3)--(i=3, k=4)
Transitivity is violated due to the match between frames 0 and 3. Transitivity is violated due to the match between frames 0 and 3.
""" """
nontransitive_matches_dict = { nontransitive_matches = {
(0, 1): np.array([[0, 2]]), (0, 1): np.array([[0, 2]]),
(1, 2): np.array([[2, 3]]), (1, 2): np.array([[2, 3]]),
(0, 2): np.array([[0, 3]]), (0, 2): np.array([[0, 3]]),
(0, 3): np.array([[1, 4]]), (0, 3): np.array([[1, 4]]),
(2, 3): np.array([[3, 4]]), (2, 3): np.array([[3, 4]]),
} }
return nontransitive_matches_dict return nontransitive_matches
if __name__ == "__main__": if __name__ == "__main__":