diff --git a/python/gtsam/tests/test_dsf_map.py b/python/gtsam/tests/test_dsf_map.py index e36657178..0ab8b678f 100644 --- a/python/gtsam/tests/test_dsf_map.py +++ b/python/gtsam/tests/test_dsf_map.py @@ -6,49 +6,68 @@ All Rights Reserved See LICENSE for the license information Unit tests for Disjoint Set Forest. -Author: Frank Dellaert & Varun Agrawal +Author: Frank Dellaert & Varun Agrawal & John Lambert """ # pylint: disable=invalid-name, no-name-in-module, no-member from __future__ import print_function import unittest +from typing import Tuple import gtsam +from gtsam import IndexPair from gtsam.utils.test_case import GtsamTestCase class TestDSFMap(GtsamTestCase): """Tests for DSFMap.""" - def test_all(self): + def test_all(self) -> None: """Test everything in DFSMap.""" - def key(index_pair): + + def key(index_pair) -> Tuple[int, int]: return index_pair.i(), index_pair.j() dsf = gtsam.DSFMapIndexPair() pair1 = gtsam.IndexPair(1, 18) self.assertEqual(key(dsf.find(pair1)), key(pair1)) pair2 = gtsam.IndexPair(2, 2) - + # testing the merge feature of dsf dsf.merge(pair1, pair2) self.assertEqual(key(dsf.find(pair1)), key(dsf.find(pair2))) - def test_sets(self): - from gtsam import IndexPair + def test_sets(self) -> None: + """Ensure that unique keys are merged correctly during Union-Find. + + An IndexPair (i,k) representing a unique key might represent the + k'th detected keypoint in image i. For the data below, merging such + measurements into feature tracks across frames should create 2 distinct sets. + """ dsf = gtsam.DSFMapIndexPair() - dsf.merge(IndexPair(0, 1), IndexPair(1,2)) - dsf.merge(IndexPair(0, 1), IndexPair(3,4)) - dsf.merge(IndexPair(4,5), IndexPair(6,8)) + dsf.merge(IndexPair(0, 1), IndexPair(1, 2)) + dsf.merge(IndexPair(0, 1), IndexPair(3, 4)) + dsf.merge(IndexPair(4, 5), IndexPair(6, 8)) sets = dsf.sets() + merged_sets = set() + for i in sets: + set_keys = [] s = sets[i] for val in gtsam.IndexPairSetAsArray(s): - val.i() - val.j() + set_keys.append((val.i(), val.j())) + merged_sets.add(tuple(set_keys)) + + # fmt: off + expected_sets = { + ((0, 1), (1, 2), (3, 4)), # set 1 + ((4, 5), (6, 8)) # set 2 + } + # fmt: on + assert expected_sets == merged_sets -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()