diff --git a/gtsam_unstable/base/DSFMap.h b/gtsam_unstable/base/DSFMap.h index c9a1293f7..bfc6d4b84 100644 --- a/gtsam_unstable/base/DSFMap.h +++ b/gtsam_unstable/base/DSFMap.h @@ -18,10 +18,10 @@ #pragma once +#include #include #include -#include - +#include namespace gtsam { @@ -33,45 +33,252 @@ namespace gtsam { template class DSFMap { - /// We store the forest in an STL map - typedef std::map Map; - typedef std::set Set; - typedef std::pair key_pair; - mutable Map parent_; +protected: + + /// We store the forest in an STL map + typedef std::map Map; + typedef std::set Set; + typedef std::pair key_pair; + mutable Map parent_; public: - /// constructor - DSFMap() {} + /// constructor + DSFMap() { + } - /// find the label of the set in which {key} lives - KEY find(const KEY& key) const { - typename Map::const_iterator it = parent_.find(key); - // if key does not exist, create and return itself - if (it==parent_.end()) { - parent_[key] = key; - return key; - } else { - // follow parent pointers until we reach set representative - KEY parent = it->second; - if (parent != key) - parent = find(parent); // not yet, recurse! - parent_[key] = parent; // path compression - return parent; - } - } + /// find the label of the set in which {key} lives + KEY find(const KEY& key) const { + typename Map::const_iterator it = parent_.find(key); + // if key does not exist, create and return itself + if (it == parent_.end()) { + parent_[key] = key; + return key; + } else { + // follow parent pointers until we reach set representative + KEY parent = it->second; + if (parent != key) + parent = find(parent); // not yet, recurse! + parent_[key] = parent; // path compression + return parent; + } + } - /// Merge two sets - void merge(const KEY& i1, const KEY& i2) { - parent_[find(i2)] = find(i1); - } + /// Merge two sets + void merge(const KEY& i1, const KEY& i2) { + parent_[find(i2)] = find(i1); + } - /// return all sets, i.e. a partition of all elements - std::map sets() const { - std::map sets; - BOOST_FOREACH(const key_pair& pair, parent_) - sets[find(pair.second)].insert(pair.first); - return sets; - } + /// return all sets, i.e. a partition of all elements + std::map sets() const { + std::map sets; + BOOST_FOREACH(const key_pair& pair, parent_) + sets[find(pair.second)].insert(pair.first); + return sets; + } + +}; + +/** + * Disjoint set forest using an STL map data structure underneath + * Uses rank compression but not union by rank :-( + * @addtogroup base + */ +template +class DSFMapIt { + +protected: + + /// We store the forest in an STL map, but parents are done with pointers + struct Entry { + typedef std::map Map; + typename Map::iterator parent_; + size_t rank_; + Entry() : + rank_(0) { + } + void makeRoot(const typename Map::iterator& it) { + parent_ = it; + } + }; + mutable typename Entry::Map entries_; + + /// find the initial Entry + typename Entry::Map::iterator find__(const KEY& key) const { + typename Entry::Map::iterator it = entries_.find(key); + // if key does not exist, create and return itself + if (it == entries_.end()) { + it = entries_.insert(it, std::make_pair(key, Entry())); + it->second.makeRoot(it); + } + return it; + } + + /// find the root Entry + typename Entry::Map::iterator find_(const KEY& key) const { + typename Entry::Map::iterator initial = find__(key); + // follow parent pointers until we reach set representative + typename Entry::Map::iterator parent = initial->second.parent_; + while (parent->second.parent_ != parent) + parent = parent->second.parent_; // not yet, recurse! + //initial.parent_ = parent; // path compression + return parent; + } + +public: + /// constructor + DSFMapIt() { + } + + /// find the representative KEY for the set in which key lives + KEY find(const KEY& key) const { + typename Entry::Map::iterator root = find_(key); + return root->first; + } + + /// Merge two sets + void merge(const KEY& x, const KEY& y) { + + // straight from http://en.wikipedia.org/wiki/Disjoint-set_data_structure + typename Entry::Map::iterator xRoot = find_(x); + typename Entry::Map::iterator yRoot = find_(y); + if (xRoot == yRoot) + return; + + // Merge sets + size_t xRootRank = xRoot->second.rank_, yRootRank = yRoot->second.rank_; + if (xRootRank < yRootRank) + xRoot->second.parent_ = yRoot; + else if (xRootRank > yRootRank) + yRoot->second.parent_ = xRoot; + else { + yRoot->second.parent_ = xRoot; + xRoot->second.rank_ = xRootRank + 1; + } + } + +}; + +/** + * Disjoint set forest using an STL map data structure underneath + * Uses rank compression but not union by rank :-( + * @addtogroup base + */ +template +class DSFMap2 { + +protected: + + /// We store the forest in an STL map, but parents are done with pointers + struct Entry { + KEY key_; + size_t rank_; + Entry* parent_; + Entry(KEY key) : + key_(key), rank_(0), parent_(0) { + } + void makeRoot() { + parent_ = this; + } + }; + typedef std::map Map; + mutable Map entries_; + + /// find the initial Entry + Entry& find__(const KEY& key) const { + typename Map::iterator it = entries_.find(key); + // if key does not exist, create and return itself + if (it == entries_.end()) { + it = entries_.insert(it, std::make_pair(key, Entry(key))); + it->second.makeRoot(); + } + return it->second; + } + + /// find the root Entry + Entry* find_(const KEY& key) const { + Entry& initial = find__(key); + // follow parent pointers until we reach set representative + Entry* parent = initial.parent_; + while (parent->parent_ != parent) + parent = parent->parent_; // not yet, recurse! + initial.parent_ = parent; // path compression + return parent; + } + +public: + /// constructor + DSFMap2() { + } + + /// find the representative KEY for the set in which key lives + KEY find(const KEY& key) const { + Entry* root = find_(key); + return root->key_; + } + + /// Merge two sets + void merge(const KEY& x, const KEY& y) { + + // straight from http://en.wikipedia.org/wiki/Disjoint-set_data_structure + Entry* xRoot = find_(x); + Entry* yRoot = find_(y); + if (xRoot == yRoot) + return; + + // Merge sets + if (xRoot->rank_ < yRoot->rank_) + xRoot->parent_ = yRoot; + else if (xRoot->rank_ > yRoot->rank_) + yRoot->parent_ = xRoot; + else { + yRoot->parent_ = xRoot; + xRoot->rank_ = xRoot->rank_ + 1; + } + } + +}; + +/** + * DSFMap version that uses union by rank :-) + * @addtogroup base + */ +template +class DSFMap3: public DSFMap { + + /// We store rank in an STL map as well + typedef std::map Ranks; + mutable Ranks rank_; + + size_t rank(const KEY& i) const { + typename Ranks::const_iterator it = rank_.find(i); + return it == rank_.end() ? 0 : it->second; + } + +public: + /// constructor + DSFMap3() { + } + + /// Merge two sets + void merge(const KEY& x, const KEY& y) { + + // straight from http://en.wikipedia.org/wiki/Disjoint-set_data_structure + KEY xRoot = this->find(x); + KEY yRoot = this->find(y); + if (xRoot == yRoot) + return; + + // Merge sets + size_t xRootRank = rank(xRoot), yRootRank = rank(yRoot); + if (xRootRank < yRootRank) + this->parent_[xRoot] = yRoot; + else if (xRootRank > yRootRank) + this->parent_[yRoot] = xRoot; + else { + this->parent_[yRoot] = xRoot; + this->rank_[xRoot] = xRootRank + 1; + } + } }; diff --git a/gtsam_unstable/base/tests/testDSFMap.cpp b/gtsam_unstable/base/tests/testDSFMap.cpp index 23dd2a110..720519f13 100644 --- a/gtsam_unstable/base/tests/testDSFMap.cpp +++ b/gtsam_unstable/base/tests/testDSFMap.cpp @@ -22,9 +22,9 @@ #include #include using namespace boost::assign; -// + #include -// + #include using namespace std; @@ -32,7 +32,7 @@ using namespace gtsam; /* ************************************************************************* */ TEST(DSFMap, find) { - DSFMap dsf; + DSFMapIt dsf; EXPECT(dsf.find(0)==0); EXPECT(dsf.find(2)==2); EXPECT(dsf.find(0)==0); @@ -42,20 +42,21 @@ TEST(DSFMap, find) { /* ************************************************************************* */ TEST(DSFMap, merge) { - DSFMap dsf; + DSFMapIt dsf; dsf.merge(0,2); EXPECT(dsf.find(0) == dsf.find(2)); } + /* ************************************************************************* */ TEST(DSFMap, merge2) { - DSFMap dsf; + DSFMapIt dsf; dsf.merge(2,0); EXPECT(dsf.find(0) == dsf.find(2)); } /* ************************************************************************* */ TEST(DSFMap, merge3) { - DSFMap dsf; + DSFMapIt dsf; dsf.merge(0,1); dsf.merge(1,2); EXPECT(dsf.find(0) == dsf.find(2)); @@ -70,16 +71,14 @@ TEST(DSFMap, mergePairwiseMatches) { matches += Match(1,2), Match(2,3), Match(4,5), Match(4,6); // Merge matches - DSFMap dsf; + DSFMapIt dsf; BOOST_FOREACH(const Match& m, matches) dsf.merge(m.first,m.second); // Each point is now associated with a set, represented by one of its members - size_t rep1 = 1, rep2 = 4; - EXPECT_LONGS_EQUAL(rep1,dsf.find(1)); + size_t rep1 = dsf.find(1), rep2 = dsf.find(4); EXPECT_LONGS_EQUAL(rep1,dsf.find(2)); EXPECT_LONGS_EQUAL(rep1,dsf.find(3)); - EXPECT_LONGS_EQUAL(rep2,dsf.find(4)); EXPECT_LONGS_EQUAL(rep2,dsf.find(5)); EXPECT_LONGS_EQUAL(rep2,dsf.find(6)); } @@ -102,19 +101,17 @@ TEST(DSFMap, mergePairwiseMatches2) { matches += Match(m11,m22), Match(m12,m23), Match(m14,m25), Match(m14,m26); // Merge matches - DSFMap dsf; + DSFMapIt dsf; BOOST_FOREACH(const Match& m, matches) dsf.merge(m.first,m.second); // Check that sets are merged correctly - EXPECT(dsf.find(m11)==m11); - EXPECT(dsf.find(m12)==m12); - EXPECT(dsf.find(m14)==m14); - EXPECT(dsf.find(m22)==m11); - EXPECT(dsf.find(m23)==m12); - EXPECT(dsf.find(m25)==m14); - EXPECT(dsf.find(m26)==m14); + EXPECT(dsf.find(m22)==dsf.find(m11)); + EXPECT(dsf.find(m23)==dsf.find(m12)); + EXPECT(dsf.find(m25)==dsf.find(m14)); + EXPECT(dsf.find(m26)==dsf.find(m14)); } + /* ************************************************************************* */ TEST(DSFMap, sets){ // Create some "matches" @@ -143,6 +140,7 @@ TEST(DSFMap, sets){ EXPECT(s1 == sets[1]); EXPECT(s2 == sets[4]); } + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */ diff --git a/gtsam_unstable/base/tests/timeDSFvariants.cpp b/gtsam_unstable/base/tests/timeDSFvariants.cpp index c2b9ebe2e..938d8dde7 100644 --- a/gtsam_unstable/base/tests/timeDSFvariants.cpp +++ b/gtsam_unstable/base/tests/timeDSFvariants.cpp @@ -45,7 +45,7 @@ int main(int argc, char* argv[]) { // loop over number of images vector ms; - ms += 10, 20, 30, 40, 50, 100, 200, 300, 400, 500; + ms += 10, 20, 30, 40, 50, 100, 200, 300, 400, 500, 1000; BOOST_FOREACH(size_t m,ms) { // We use volatile here to make these appear to the optimizing compiler as // if their values are only known at run-time. @@ -82,7 +82,7 @@ int main(int argc, char* argv[]) { { // DSFMap version timer tim; - DSFMap dsf; + DSFMapIt dsf; BOOST_FOREACH(const Match& m, matches) dsf.merge(m.first, m.second); os << tim.elapsed() << ","; @@ -90,6 +90,16 @@ int main(int argc, char* argv[]) { } { + // DSFMap2 version + timer tim; + DSFMap2 dsf; + BOOST_FOREACH(const Match& m, matches) + dsf.merge(m.first, m.second); + os << tim.elapsed() << endl; + cout << format("DSFMap: %1% s") % tim.elapsed() << endl; + } + + if (false) { // DSF version, functional timer tim; DSF dsf;