Added DSFMap to wrapper, as well as IndexPair

release/4.3a0
dellaert 2019-04-17 20:05:28 -04:00
parent 000ccc0bcc
commit 85934fd8ca
4 changed files with 79 additions and 18 deletions

View File

@ -0,0 +1,38 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Disjoint Set Forest.
Author: Frank Dellaert & Varun Agrawal
"""
# pylint: disable=invalid-name, no-name-in-module, no-member
from __future__ import print_function
import unittest
import gtsam
from gtsam.utils.test_case import GtsamTestCase
class TestDSFMap(GtsamTestCase):
"""Tests for DSFMap."""
def test_all(self):
"""Test everything in DFSMap."""
def key(index_pair):
return index_pair.i(), index_pair.j()
dsf = gtsam.DSFMapIndexPair()
pair1 = gtsam.IndexPair(1, 18)
self.assertEqual(key(dsf.find(pair1)), key(pair1))
pair2 = gtsam.IndexPair(2, 2)
dsf.merge(pair1, pair2)
self.assertTrue(dsf.find(pair1), dsf.find(pair1))
if __name__ == '__main__':
unittest.main()

17
gtsam.h
View File

@ -241,11 +241,28 @@ class FactorIndices {
size_t back() const;
void push_back(size_t factorIndex) const;
};
//*************************************************************************
// base
//*************************************************************************
/** gtsam namespace functions */
#include <gtsam/base/DSFMap.h>
class IndexPair {
IndexPair();
IndexPair(size_t i, size_t j);
size_t i() const;
size_t j() const;
};
template<KEY = {gtsam::IndexPair}>
class DSFMap {
DSFMap();
KEY find(const KEY& key) const;
void merge(const KEY& x, const KEY& y);
};
#include <gtsam/base/Matrix.h>
bool linear_independent(Matrix A, Matrix B, double tol);

View File

@ -1,6 +1,6 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
@ -18,9 +18,9 @@
#pragma once
#include <cstdlib> // Provides size_t
#include <map>
#include <set>
#include <cstdlib> // Provides size_t
namespace gtsam {
@ -29,11 +29,9 @@ namespace gtsam {
* Uses rank compression and union by rank, iterator version
* @addtogroup base
*/
template<class KEY>
template <class KEY>
class DSFMap {
protected:
protected:
/// We store the forest in an STL map, but parents are done with pointers
struct Entry {
typename std::map<KEY, Entry>::iterator parent_;
@ -41,7 +39,7 @@ protected:
Entry() {}
};
typedef typename std::map<KEY, Entry> Map;
typedef typename std::map<KEY, Entry> Map;
typedef typename Map::iterator iterator;
mutable Map entries_;
@ -62,8 +60,7 @@ protected:
iterator find_(const iterator& it) const {
// follow parent pointers until we reach set representative
iterator& parent = it->second.parent_;
if (parent != it)
parent = find_(parent); // not yet, recurse!
if (parent != it) parent = find_(parent); // not yet, recurse!
return parent;
}
@ -73,13 +70,11 @@ protected:
return find_(initial);
}
public:
public:
typedef std::set<KEY> Set;
/// constructor
DSFMap() {
}
DSFMap() {}
/// Given key, find the representative key for the set in which it lives
inline KEY find(const KEY& key) const {
@ -89,12 +84,10 @@ public:
/// Merge two sets
void merge(const KEY& x, const KEY& y) {
// straight from http://en.wikipedia.org/wiki/Disjoint-set_data_structure
iterator xRoot = find_(x);
iterator yRoot = find_(y);
if (xRoot == yRoot)
return;
if (xRoot == yRoot) return;
// Merge sets
if (xRoot->second.rank_ < yRoot->second.rank_)
@ -117,7 +110,14 @@ public:
}
return sets;
}
};
}
/// Small utility class for representing a wrappable pairs of ints.
class IndexPair : public std::pair<size_t,size_t> {
public:
IndexPair(): std::pair<size_t,size_t>(0,0) {}
IndexPair(size_t i, size_t j) : std::pair<size_t,size_t>(i,j) {}
inline size_t i() const { return first; };
inline size_t j() const { return second; };
};
} // namespace gtsam

View File

@ -139,6 +139,12 @@ TEST(DSFMap, sets){
EXPECT(s2 == sets[4]);
}
/* ************************************************************************* */
TEST(DSFMap, findIndexPair) {
DSFMap<IndexPair> dsf;
EXPECT(dsf.find(IndexPair(1,2))==IndexPair(1,2));
EXPECT(dsf.find(IndexPair(1,2)) != dsf.find(IndexPair(1,3)));
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
/* ************************************************************************* */