Cherrypick transitivity fix for DsfTrackGenerator
							parent
							
								
									13c7dafba3
								
							
						
					
					
						commit
						12f919dc55
					
				|  | @ -20,6 +20,7 @@ | |||
| 
 | ||||
| #include <algorithm> | ||||
| #include <iostream> | ||||
| #include <iomanip> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
|  | @ -38,7 +39,8 @@ static DSFMapIndexPair generateDSF(const MatchIndicesMap& matches) { | |||
|     // Image pair is (i1,i2).
 | ||||
|     size_t i1 = pair_indices.first; | ||||
|     size_t i2 = pair_indices.second; | ||||
|     for (size_t k = 0; k < corr_indices.rows(); k++) { | ||||
|     size_t m = static_cast<size_t>(corr_indices.rows()); | ||||
|     for (size_t k = 0; k < m; k++) { | ||||
|       // Measurement indices are found in a single matrix row, as (k1,k2).
 | ||||
|       size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1); | ||||
|       // Unique key for DSF is (i,k), representing keypoint index in an image.
 | ||||
|  | @ -128,7 +130,7 @@ std::vector<SfmTrack2d> tracksFromPairwiseMatches( | |||
|   } | ||||
| 
 | ||||
|   // TODO(johnwlambert): return the Transitivity failure percentage here.
 | ||||
|   return tracks2d; | ||||
|   return validTracks; | ||||
| } | ||||
| 
 | ||||
| }  // namespace gtsfm
 | ||||
|  |  | |||
|  | @ -4,18 +4,42 @@ Authors: John Lambert | |||
| """ | ||||
| 
 | ||||
| import unittest | ||||
| from typing import Dict, List, Tuple | ||||
| 
 | ||||
| import gtsam | ||||
| import numpy as np | ||||
| from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2, | ||||
|                    SfmMeasurementVector, SfmTrack2d) | ||||
| from gtsam.gtsfm import Keypoints | ||||
| from gtsam.utils.test_case import GtsamTestCase | ||||
| 
 | ||||
| import gtsam | ||||
| from gtsam import IndexPair, Point2, SfmTrack2d | ||||
| 
 | ||||
| 
 | ||||
| class TestDsfTrackGenerator(GtsamTestCase): | ||||
|     """Tests for DsfTrackGenerator.""" | ||||
| 
 | ||||
|     def test_generate_tracks_from_pairwise_matches_nontransitive( | ||||
|         self, | ||||
|     ) -> None: | ||||
|         """Tests DSF for non-transitive matches. | ||||
| 
 | ||||
|         Test will result in no tracks since nontransitive tracks are naively discarded by DSF. | ||||
|         """ | ||||
|         keypoints_list = get_dummy_keypoints_list() | ||||
|         nontransitive_matches_dict = get_nontransitive_matches()  # contains one non-transitive track | ||||
| 
 | ||||
|         # For each image pair (i1,i2), we provide a (K,2) matrix | ||||
|         # of corresponding keypoint indices (k1,k2). | ||||
|         matches_dict = {} | ||||
|         for (i1,i2), corr_idxs in nontransitive_matches_dict.items(): | ||||
|             matches_dict[IndexPair(i1, i2)] = corr_idxs | ||||
| 
 | ||||
|         tracks = gtsam.gtsfm.tracksFromPairwiseMatches( | ||||
|             matches_dict, | ||||
|             keypoints_list, | ||||
|             verbose=True, | ||||
|         ) | ||||
|         self.assertEqual(len(tracks), 0, "Tracks not filtered correctly") | ||||
| 
 | ||||
|     def test_track_generation(self) -> None: | ||||
|         """Ensures that DSF generates three tracks from measurements | ||||
|         in 3 images (H=200,W=400).""" | ||||
|  | @ -23,14 +47,14 @@ class TestDsfTrackGenerator(GtsamTestCase): | |||
|         kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]])) | ||||
|         kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]])) | ||||
| 
 | ||||
|         keypoints_list = KeypointsVector() | ||||
|         keypoints_list = [] | ||||
|         keypoints_list.append(kps_i0) | ||||
|         keypoints_list.append(kps_i1) | ||||
|         keypoints_list.append(kps_i2) | ||||
| 
 | ||||
|         # For each image pair (i1,i2), we provide a (K,2) matrix | ||||
|         # of corresponding image indices (k1,k2). | ||||
|         matches_dict = MatchIndicesMap() | ||||
|         # of corresponding keypoint indices (k1,k2). | ||||
|         matches_dict = {} | ||||
|         matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]]) | ||||
|         matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]]) | ||||
| 
 | ||||
|  | @ -86,12 +110,80 @@ class TestSfmTrack2d(GtsamTestCase): | |||
| 
 | ||||
|     def test_sfm_track_2d_constructor(self) -> None: | ||||
|         """Test construction of 2D SfM track.""" | ||||
|         measurements = SfmMeasurementVector() | ||||
|         measurements = [] | ||||
|         measurements.append((0, Point2(10, 20))) | ||||
|         track = SfmTrack2d(measurements=measurements) | ||||
|         track.measurement(0) | ||||
|         assert track.numberMeasurements() == 1 | ||||
| 
 | ||||
| 
 | ||||
| def get_dummy_keypoints_list() -> List[Keypoints]: | ||||
|     """ """ | ||||
|     img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]]) | ||||
|     img1_kp_scale = np.array([6.0, 9.0, 8.5]) | ||||
|     img2_kp_coords = np.array( | ||||
|         [ | ||||
|             [1, 1.], | ||||
|             [2, 2], | ||||
|             [3, 3], | ||||
|             [4, 4], | ||||
|             [5, 5], | ||||
|             [6, 6], | ||||
|             [7, 7], | ||||
|             [8, 8], | ||||
|         ] | ||||
|     ) | ||||
|     img3_kp_coords = np.array( | ||||
|         [ | ||||
|             [1, 1.], | ||||
|             [2, 2], | ||||
|             [3, 3], | ||||
|             [4, 4], | ||||
|             [5, 5], | ||||
|             [6, 6], | ||||
|             [7, 7], | ||||
|             [8, 8], | ||||
|             [9, 9], | ||||
|             [10, 10], | ||||
|         ] | ||||
|     ) | ||||
|     img4_kp_coords = np.array( | ||||
|         [ | ||||
|             [1, 1.], | ||||
|             [2, 2], | ||||
|             [3, 3], | ||||
|             [4, 4], | ||||
|             [5, 5], | ||||
|         ] | ||||
|     ) | ||||
|     keypoints_list = [ | ||||
|         Keypoints(coordinates=img1_kp_coords), | ||||
|         Keypoints(coordinates=img2_kp_coords), | ||||
|         Keypoints(coordinates=img3_kp_coords), | ||||
|         Keypoints(coordinates=img4_kp_coords), | ||||
|     ] | ||||
|     return keypoints_list | ||||
| 
 | ||||
| 
 | ||||
| def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]: | ||||
|     """Set up correspondences for each (i1,i2) pair that violates transitivity. | ||||
|      | ||||
|     (i=0, k=0)             (i=0, k=1) | ||||
|          |    \\               | | ||||
|          |     \\              | | ||||
|     (i=1, k=2)--(i=2,k=3)--(i=3, k=4) | ||||
| 
 | ||||
|     Transitivity is violated due to the match between frames 0 and 3.  | ||||
|     """ | ||||
|     nontransitive_matches_dict = { | ||||
|         (0, 1): np.array([[0, 2]]), | ||||
|         (1, 2): np.array([[2, 3]]), | ||||
|         (0, 2): np.array([[0, 3]]), | ||||
|         (0, 3): np.array([[1, 4]]), | ||||
|         (2, 3): np.array([[3, 4]]), | ||||
|     } | ||||
|     return nontransitive_matches_dict | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue