Expand DSF map unit tests

release/4.3a0
John Lambert 2021-08-21 20:16:39 -06:00 committed by GitHub
parent 52161785cf
commit 068e558d34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 31 additions and 12 deletions

View File

@ -6,49 +6,68 @@ All Rights Reserved
See LICENSE for the license information
Unit tests for Disjoint Set Forest.
Author: Frank Dellaert & Varun Agrawal
Author: Frank Dellaert & Varun Agrawal & John Lambert
"""
# pylint: disable=invalid-name, no-name-in-module, no-member
from __future__ import print_function
import unittest
from typing import Tuple
import gtsam
from gtsam import IndexPair
from gtsam.utils.test_case import GtsamTestCase
class TestDSFMap(GtsamTestCase):
"""Tests for DSFMap."""
def test_all(self):
def test_all(self) -> None:
"""Test everything in DFSMap."""
def key(index_pair):
def key(index_pair) -> Tuple[int, int]:
return index_pair.i(), index_pair.j()
dsf = gtsam.DSFMapIndexPair()
pair1 = gtsam.IndexPair(1, 18)
self.assertEqual(key(dsf.find(pair1)), key(pair1))
pair2 = gtsam.IndexPair(2, 2)
# testing the merge feature of dsf
dsf.merge(pair1, pair2)
self.assertEqual(key(dsf.find(pair1)), key(dsf.find(pair2)))
def test_sets(self):
from gtsam import IndexPair
def test_sets(self) -> None:
"""Ensure that unique keys are merged correctly during Union-Find.
An IndexPair (i,k) representing a unique key might represent the
k'th detected keypoint in image i. For the data below, merging such
measurements into feature tracks across frames should create 2 distinct sets.
"""
dsf = gtsam.DSFMapIndexPair()
dsf.merge(IndexPair(0, 1), IndexPair(1,2))
dsf.merge(IndexPair(0, 1), IndexPair(3,4))
dsf.merge(IndexPair(4,5), IndexPair(6,8))
dsf.merge(IndexPair(0, 1), IndexPair(1, 2))
dsf.merge(IndexPair(0, 1), IndexPair(3, 4))
dsf.merge(IndexPair(4, 5), IndexPair(6, 8))
sets = dsf.sets()
merged_sets = set()
for i in sets:
set_keys = []
s = sets[i]
for val in gtsam.IndexPairSetAsArray(s):
val.i()
val.j()
set_keys.append((val.i(), val.j()))
merged_sets.add(tuple(set_keys))
# fmt: off
expected_sets = {
((0, 1), (1, 2), (3, 4)), # set 1
((4, 5), (6, 8)) # set 2
}
# fmt: on
assert expected_sets == merged_sets
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()