Cherrypick transitivity fix for DsfTrackGenerator

release/4.3a0
senselessdev1 2023-08-30 22:09:07 -04:00
parent 13c7dafba3
commit 12f919dc55
2 changed files with 103 additions and 9 deletions

View File

@ -20,6 +20,7 @@
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <iomanip>
namespace gtsam { namespace gtsam {
@ -38,7 +39,8 @@ static DSFMapIndexPair generateDSF(const MatchIndicesMap& matches) {
// Image pair is (i1,i2). // Image pair is (i1,i2).
size_t i1 = pair_indices.first; size_t i1 = pair_indices.first;
size_t i2 = pair_indices.second; size_t i2 = pair_indices.second;
for (size_t k = 0; k < corr_indices.rows(); k++) { size_t m = static_cast<size_t>(corr_indices.rows());
for (size_t k = 0; k < m; k++) {
// Measurement indices are found in a single matrix row, as (k1,k2). // Measurement indices are found in a single matrix row, as (k1,k2).
size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1); size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1);
// Unique key for DSF is (i,k), representing keypoint index in an image. // Unique key for DSF is (i,k), representing keypoint index in an image.
@ -128,7 +130,7 @@ std::vector<SfmTrack2d> tracksFromPairwiseMatches(
} }
// TODO(johnwlambert): return the Transitivity failure percentage here. // TODO(johnwlambert): return the Transitivity failure percentage here.
return tracks2d; return validTracks;
} }
} // namespace gtsfm } // namespace gtsfm

View File

@ -4,18 +4,42 @@ Authors: John Lambert
""" """
import unittest import unittest
from typing import Dict, List, Tuple
import gtsam
import numpy as np import numpy as np
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
SfmMeasurementVector, SfmTrack2d)
from gtsam.gtsfm import Keypoints from gtsam.gtsfm import Keypoints
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
import gtsam
from gtsam import IndexPair, Point2, SfmTrack2d
class TestDsfTrackGenerator(GtsamTestCase): class TestDsfTrackGenerator(GtsamTestCase):
"""Tests for DsfTrackGenerator.""" """Tests for DsfTrackGenerator."""
def test_generate_tracks_from_pairwise_matches_nontransitive(
self,
) -> None:
"""Tests DSF for non-transitive matches.
Test will result in no tracks since nontransitive tracks are naively discarded by DSF.
"""
keypoints_list = get_dummy_keypoints_list()
nontransitive_matches_dict = get_nontransitive_matches() # contains one non-transitive track
# For each image pair (i1,i2), we provide a (K,2) matrix
# of corresponding keypoint indices (k1,k2).
matches_dict = {}
for (i1,i2), corr_idxs in nontransitive_matches_dict.items():
matches_dict[IndexPair(i1, i2)] = corr_idxs
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
matches_dict,
keypoints_list,
verbose=True,
)
self.assertEqual(len(tracks), 0, "Tracks not filtered correctly")
def test_track_generation(self) -> None: def test_track_generation(self) -> None:
"""Ensures that DSF generates three tracks from measurements """Ensures that DSF generates three tracks from measurements
in 3 images (H=200,W=400).""" in 3 images (H=200,W=400)."""
@ -23,14 +47,14 @@ class TestDsfTrackGenerator(GtsamTestCase):
kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]])) kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]])) kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))
keypoints_list = KeypointsVector() keypoints_list = []
keypoints_list.append(kps_i0) keypoints_list.append(kps_i0)
keypoints_list.append(kps_i1) keypoints_list.append(kps_i1)
keypoints_list.append(kps_i2) keypoints_list.append(kps_i2)
# For each image pair (i1,i2), we provide a (K,2) matrix # For each image pair (i1,i2), we provide a (K,2) matrix
# of corresponding image indices (k1,k2). # of corresponding keypoint indices (k1,k2).
matches_dict = MatchIndicesMap() matches_dict = {}
matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]]) matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]]) matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
@ -86,12 +110,80 @@ class TestSfmTrack2d(GtsamTestCase):
def test_sfm_track_2d_constructor(self) -> None: def test_sfm_track_2d_constructor(self) -> None:
"""Test construction of 2D SfM track.""" """Test construction of 2D SfM track."""
measurements = SfmMeasurementVector() measurements = []
measurements.append((0, Point2(10, 20))) measurements.append((0, Point2(10, 20)))
track = SfmTrack2d(measurements=measurements) track = SfmTrack2d(measurements=measurements)
track.measurement(0) track.measurement(0)
assert track.numberMeasurements() == 1 assert track.numberMeasurements() == 1
def get_dummy_keypoints_list() -> List[Keypoints]:
""" """
img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]])
img1_kp_scale = np.array([6.0, 9.0, 8.5])
img2_kp_coords = np.array(
[
[1, 1.],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
[6, 6],
[7, 7],
[8, 8],
]
)
img3_kp_coords = np.array(
[
[1, 1.],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
[6, 6],
[7, 7],
[8, 8],
[9, 9],
[10, 10],
]
)
img4_kp_coords = np.array(
[
[1, 1.],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
]
)
keypoints_list = [
Keypoints(coordinates=img1_kp_coords),
Keypoints(coordinates=img2_kp_coords),
Keypoints(coordinates=img3_kp_coords),
Keypoints(coordinates=img4_kp_coords),
]
return keypoints_list
def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]:
"""Set up correspondences for each (i1,i2) pair that violates transitivity.
(i=0, k=0) (i=0, k=1)
| \\ |
| \\ |
(i=1, k=2)--(i=2,k=3)--(i=3, k=4)
Transitivity is violated due to the match between frames 0 and 3.
"""
nontransitive_matches_dict = {
(0, 1): np.array([[0, 2]]),
(1, 2): np.array([[2, 3]]),
(0, 2): np.array([[0, 3]]),
(0, 3): np.array([[1, 4]]),
(2, 3): np.array([[3, 4]]),
}
return nontransitive_matches_dict
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()