Expand DSF map unit tests

release/4.3a0
John Lambert 2021-08-21 20:16:39 -06:00 committed by GitHub
parent 52161785cf
commit 068e558d34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 31 additions and 12 deletions

View File

@ -6,49 +6,68 @@ All Rights Reserved
See LICENSE for the license information See LICENSE for the license information
Unit tests for Disjoint Set Forest. Unit tests for Disjoint Set Forest.
Author: Frank Dellaert & Varun Agrawal Author: Frank Dellaert & Varun Agrawal & John Lambert
""" """
# pylint: disable=invalid-name, no-name-in-module, no-member # pylint: disable=invalid-name, no-name-in-module, no-member
from __future__ import print_function from __future__ import print_function
import unittest import unittest
from typing import Tuple
import gtsam import gtsam
from gtsam import IndexPair
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
class TestDSFMap(GtsamTestCase): class TestDSFMap(GtsamTestCase):
"""Tests for DSFMap.""" """Tests for DSFMap."""
def test_all(self): def test_all(self) -> None:
"""Test everything in DFSMap.""" """Test everything in DFSMap."""
def key(index_pair):
def key(index_pair) -> Tuple[int, int]:
return index_pair.i(), index_pair.j() return index_pair.i(), index_pair.j()
dsf = gtsam.DSFMapIndexPair() dsf = gtsam.DSFMapIndexPair()
pair1 = gtsam.IndexPair(1, 18) pair1 = gtsam.IndexPair(1, 18)
self.assertEqual(key(dsf.find(pair1)), key(pair1)) self.assertEqual(key(dsf.find(pair1)), key(pair1))
pair2 = gtsam.IndexPair(2, 2) pair2 = gtsam.IndexPair(2, 2)
# testing the merge feature of dsf # testing the merge feature of dsf
dsf.merge(pair1, pair2) dsf.merge(pair1, pair2)
self.assertEqual(key(dsf.find(pair1)), key(dsf.find(pair2))) self.assertEqual(key(dsf.find(pair1)), key(dsf.find(pair2)))
def test_sets(self): def test_sets(self) -> None:
from gtsam import IndexPair """Ensure that unique keys are merged correctly during Union-Find.
An IndexPair (i,k) representing a unique key might represent the
k'th detected keypoint in image i. For the data below, merging such
measurements into feature tracks across frames should create 2 distinct sets.
"""
dsf = gtsam.DSFMapIndexPair() dsf = gtsam.DSFMapIndexPair()
dsf.merge(IndexPair(0, 1), IndexPair(1,2)) dsf.merge(IndexPair(0, 1), IndexPair(1, 2))
dsf.merge(IndexPair(0, 1), IndexPair(3,4)) dsf.merge(IndexPair(0, 1), IndexPair(3, 4))
dsf.merge(IndexPair(4,5), IndexPair(6,8)) dsf.merge(IndexPair(4, 5), IndexPair(6, 8))
sets = dsf.sets() sets = dsf.sets()
merged_sets = set()
for i in sets: for i in sets:
set_keys = []
s = sets[i] s = sets[i]
for val in gtsam.IndexPairSetAsArray(s): for val in gtsam.IndexPairSetAsArray(s):
val.i() set_keys.append((val.i(), val.j()))
val.j() merged_sets.add(tuple(set_keys))
# fmt: off
expected_sets = {
((0, 1), (1, 2), (3, 4)), # set 1
((4, 5), (6, 8)) # set 2
}
# fmt: on
assert expected_sets == merged_sets
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()