Merge pull request #854 from borglab/expand-dsf-map-unit-tests
commit
4739f70f40
|
@ -6,49 +6,68 @@ All Rights Reserved
|
||||||
See LICENSE for the license information
|
See LICENSE for the license information
|
||||||
|
|
||||||
Unit tests for Disjoint Set Forest.
|
Unit tests for Disjoint Set Forest.
|
||||||
Author: Frank Dellaert & Varun Agrawal
|
Author: Frank Dellaert & Varun Agrawal & John Lambert
|
||||||
"""
|
"""
|
||||||
# pylint: disable=invalid-name, no-name-in-module, no-member
|
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||||
|
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import gtsam
|
import gtsam
|
||||||
|
from gtsam import IndexPair
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
class TestDSFMap(GtsamTestCase):
|
class TestDSFMap(GtsamTestCase):
|
||||||
"""Tests for DSFMap."""
|
"""Tests for DSFMap."""
|
||||||
|
|
||||||
def test_all(self):
|
def test_all(self) -> None:
|
||||||
"""Test everything in DFSMap."""
|
"""Test everything in DFSMap."""
|
||||||
def key(index_pair):
|
|
||||||
|
def key(index_pair) -> Tuple[int, int]:
|
||||||
return index_pair.i(), index_pair.j()
|
return index_pair.i(), index_pair.j()
|
||||||
|
|
||||||
dsf = gtsam.DSFMapIndexPair()
|
dsf = gtsam.DSFMapIndexPair()
|
||||||
pair1 = gtsam.IndexPair(1, 18)
|
pair1 = gtsam.IndexPair(1, 18)
|
||||||
self.assertEqual(key(dsf.find(pair1)), key(pair1))
|
self.assertEqual(key(dsf.find(pair1)), key(pair1))
|
||||||
pair2 = gtsam.IndexPair(2, 2)
|
pair2 = gtsam.IndexPair(2, 2)
|
||||||
|
|
||||||
# testing the merge feature of dsf
|
# testing the merge feature of dsf
|
||||||
dsf.merge(pair1, pair2)
|
dsf.merge(pair1, pair2)
|
||||||
self.assertEqual(key(dsf.find(pair1)), key(dsf.find(pair2)))
|
self.assertEqual(key(dsf.find(pair1)), key(dsf.find(pair2)))
|
||||||
|
|
||||||
def test_sets(self):
|
def test_sets(self) -> None:
|
||||||
from gtsam import IndexPair
|
"""Ensure that pairs are merged correctly during Union-Find.
|
||||||
|
|
||||||
|
An IndexPair (i,k) representing a unique key might represent the
|
||||||
|
k'th detected keypoint in image i. For the data below, merging such
|
||||||
|
measurements into feature tracks across frames should create 2 distinct sets.
|
||||||
|
"""
|
||||||
dsf = gtsam.DSFMapIndexPair()
|
dsf = gtsam.DSFMapIndexPair()
|
||||||
dsf.merge(IndexPair(0, 1), IndexPair(1,2))
|
dsf.merge(IndexPair(0, 1), IndexPair(1, 2))
|
||||||
dsf.merge(IndexPair(0, 1), IndexPair(3,4))
|
dsf.merge(IndexPair(0, 1), IndexPair(3, 4))
|
||||||
dsf.merge(IndexPair(4,5), IndexPair(6,8))
|
dsf.merge(IndexPair(4, 5), IndexPair(6, 8))
|
||||||
sets = dsf.sets()
|
sets = dsf.sets()
|
||||||
|
|
||||||
|
merged_sets = set()
|
||||||
|
|
||||||
for i in sets:
|
for i in sets:
|
||||||
|
set_keys = []
|
||||||
s = sets[i]
|
s = sets[i]
|
||||||
for val in gtsam.IndexPairSetAsArray(s):
|
for val in gtsam.IndexPairSetAsArray(s):
|
||||||
val.i()
|
set_keys.append((val.i(), val.j()))
|
||||||
val.j()
|
merged_sets.add(tuple(set_keys))
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
expected_sets = {
|
||||||
|
((0, 1), (1, 2), (3, 4)), # set 1
|
||||||
|
((4, 5), (6, 8)) # set 2
|
||||||
|
}
|
||||||
|
# fmt: on
|
||||||
|
assert expected_sets == merged_sets
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue