diff --git a/.cproject b/.cproject index 6ef68f8f2..55ad7b743 100644 --- a/.cproject +++ b/.cproject @@ -317,7 +317,6 @@ make - clean true true @@ -477,6 +476,7 @@ make + testBayesTree.run true false @@ -484,7 +484,6 @@ make - testSymbolicBayesNet.run true false @@ -492,6 +491,7 @@ make + testSymbolicFactorGraph.run true false @@ -683,6 +683,7 @@ make + testGraph.run true false @@ -738,7 +739,6 @@ make - testSimulated2D.run true false @@ -786,11 +786,20 @@ make + testErrors.run true false true + +make + +testDSF.run +true +true +true + make -j2 diff --git a/cpp/BTree.h b/cpp/BTree.h index 862501bcd..24f45a1af 100644 --- a/cpp/BTree.h +++ b/cpp/BTree.h @@ -224,7 +224,8 @@ namespace gtsam { else if (key < k) node = node->right.root_.get(); else /* (key() == k) */ return node->value(); } - throw std::invalid_argument("BTree::find: key '" + (std::string) k + "' not found"); + //throw std::invalid_argument("BTree::find: key '" + (std::string) k + "' not found"); + throw std::invalid_argument("BTree::find: key not found"); } /** print in-order */ diff --git a/cpp/DSF.h b/cpp/DSF.h new file mode 100644 index 000000000..7d8c2e9f1 --- /dev/null +++ b/cpp/DSF.h @@ -0,0 +1,134 @@ +/* + * DSF.h + * + * Created on: Mar 26, 2010 + * Author: nikai + * Description: An implementation of Disjoint set forests (see CLR page 446 and up) + * Quoting from CLR: A disjoint-set data structure maintains a collection + * S = {S_1,S_2,...} of disjoint dynamic sets. Each set is identified by + * a representative, which is some member of the set. + */ + +#pragma once + +#include +#include +#include +#include +#include "BTree.h" + +namespace gtsam { + + class Symbol; + + template + class DSF : BTree { + + public: + typedef Key Label; // label can be different from key, but for now they are same + typedef DSF Self; + typedef std::set Set; + typedef BTree Tree; + typedef std::pair KeyLabel; + + // constructor + DSF() : Tree() { } + + // constructor + DSF(const Tree& tree) : Tree(tree) {} + + // create a new singleton, does nothing if already exists + Self makeSet(const Key& key) const { if (mem(key)) return *this; else return add(key, key); } + + // find the label of the set in which {key} lives + Label findSet(const Key& key) const { + Key parent = find(key); + return parent == key ? key : findSet(parent); } + + // return a new DSF where x and y are in the same set. Kai: the caml implementation is not const, and I followed + Self makeUnion(const Key& key1, const Key& key2) { return add(findSet_(key2), findSet_(key1)); } + + // create a new singleton with two connected keys + Self makePair(const Key& key1, const Key& key2) const { return makeSet(key1).makeSet(key2).makeUnion(key1, key2); } + + // create a new singleton with a list of fully connected keys + Self makeList(const std::list& keys) const { + Self t = *this; + BOOST_FOREACH(const Key& key, keys) + t = t.makePair(key, keys.front()); + return t; + } + + // return a dsf in which all find_set operations will be O(1) due to path compression. + DSF flatten() const { + DSF t = *this; + BOOST_FOREACH(const KeyLabel& pair, (Tree)t) + t.findSet_(pair.first); + return t; + } + + // maps f over all keys, must be invertible + DSF map(boost::function func) const { + DSF t; + BOOST_FOREACH(const KeyLabel& pair, (Tree)*this) + t.add(func(pair.first), func(pair.second)); + return t; + } + + // return the number of sets + size_t numSets() const { + size_t num = 0; + BOOST_FOREACH(const KeyLabel& pair, (Tree)*this) + if (pair.first == pair.second) num++; + return num; + } + + // return the numer of keys + size_t size() const { return Tree::size(); } + + // return all sets, i.e. a partition of all elements + std::map sets() const { + std::map sets; + BOOST_FOREACH(const KeyLabel& pair, (Tree)*this) + sets[findSet(pair.second)].insert(pair.first); + return sets; + } + + // return a partition of the given elements {keys} + std::map partition(const std::list& keys) const { + std::map partitions; + BOOST_FOREACH(const Key& key, keys) + partitions[findSet(key)].insert(key); + return partitions; + } + + /** equality */ + bool operator==(const Self& t) const { return (Tree)*this == (Tree)t; } + + // print the object + void print(std::string& name = "DSF") const { Tree::print(name); } + + private: + + /** + * same as findSet except with path compression: After we have traversed the path to + * the root, each parent pointer is made to directly point to it + */ + Key findSet_(const Key& key) { + Key parent = find(key); + if (parent == key) + return parent; + else { + Key label = findSet_(parent); + *this = add(key, label); + return label; + } + } + + }; + + // shortcuts + typedef DSF DSFInt; + typedef DSF DSFSymbol; + +} // namespace gtsam diff --git a/cpp/Makefile.am b/cpp/Makefile.am index d760325c2..82614367f 100644 --- a/cpp/Makefile.am +++ b/cpp/Makefile.am @@ -40,13 +40,15 @@ testMatrix_LDADD = libgtsam.la # The header files will be installed in ~/include/gtsam headers = gtsam.h Value.h Testable.h Factor.h Conditional.h headers += Ordering.h IndexTable.h numericalDerivative.h -headers += BTree.h +headers += BTree.h DSF.h sources += Ordering.cpp smallExample.cpp -check_PROGRAMS += testOrdering testBTree +check_PROGRAMS += testOrdering testBTree testDSF testOrdering_SOURCES = testOrdering.cpp testOrdering_LDADD = libgtsam.la testBTree_SOURCES = testBTree.cpp testBTree_LDADD = libgtsam.la +testDSF_SOURCES = testDSF.cpp +testDSF_LDADD = libgtsam.la # Symbolic Inference headers += SymbolicConditional.h diff --git a/cpp/testDSF.cpp b/cpp/testDSF.cpp new file mode 100644 index 000000000..dc5ba300b --- /dev/null +++ b/cpp/testDSF.cpp @@ -0,0 +1,222 @@ +/* + * testDSF.cpp + * + * Created on: Mar 26, 2010 + * Author: nikai + * Description: unit tests for DSF + */ + +#include +#include +#include +using namespace boost::assign; +#include + +#include "DSF.h" + +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +TEST(DSF, makeSet) { + DSFInt dsf; + dsf = dsf.makeSet(5); + LONGS_EQUAL(1, dsf.size()); +} + +/* ************************************************************************* */ +TEST(DSF, findSet) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + CHECK(dsf.findSet(5) != dsf.findSet(7)); +} + +/* ************************************************************************* */ +TEST(DSF, makeUnion) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + dsf = dsf.makeUnion(5,7); + CHECK(dsf.findSet(5) == dsf.findSet(7)); +} + +/* ************************************************************************* */ +TEST(DSF, makeUnion2) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + dsf = dsf.makeUnion(7,5); + CHECK(dsf.findSet(5) == dsf.findSet(7)); +} + +/* ************************************************************************* */ +TEST(DSF, makeUnion3) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + dsf = dsf.makeUnion(5,6); + dsf = dsf.makeUnion(6,7); + CHECK(dsf.findSet(5) == dsf.findSet(7)); +} + +/* ************************************************************************* */ +TEST(DSF, makePair) { + DSFInt dsf; + dsf = dsf.makePair(0, 1); + dsf = dsf.makePair(1, 2); + dsf = dsf.makePair(3, 2); + CHECK(dsf.findSet(0) == dsf.findSet(3)); +} + +/* ************************************************************************* */ +TEST(DSF, makeList) { + DSFInt dsf; + list keys; keys += 5, 6, 7; + dsf = dsf.makeList(keys); + CHECK(dsf.findSet(5) == dsf.findSet(7)); +} + +/* ************************************************************************* */ +TEST(DSF, numSets) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + dsf = dsf.makeUnion(5,6); + LONGS_EQUAL(2, dsf.numSets()); +} + +/* ************************************************************************* */ +TEST(DSF, sets) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeUnion(5,6); + map > sets = dsf.sets(); + LONGS_EQUAL(1, sets.size()); + + set expected; expected += 5, 6; + CHECK(expected == sets[dsf.findSet(5)]); +} + +/* ************************************************************************* */ +TEST(DSF, sets2) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + dsf = dsf.makeUnion(5,6); + dsf = dsf.makeUnion(6,7); + map > sets = dsf.sets(); + LONGS_EQUAL(1, sets.size()); + + set expected; expected += 5, 6, 7; + CHECK(expected == sets[dsf.findSet(5)]); +} + +/* ************************************************************************* */ +TEST(DSF, sets3) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + dsf = dsf.makeUnion(5,6); + map > sets = dsf.sets(); + LONGS_EQUAL(2, sets.size()); + + set expected; expected += 5, 6; + CHECK(expected == sets[dsf.findSet(5)]); +} + +/* ************************************************************************* */ +TEST(DSF, partition) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeUnion(5,6); + + list keys; keys += 5; + map > partitions = dsf.partition(keys); + LONGS_EQUAL(1, partitions.size()); + + set expected; expected += 5; + CHECK(expected == partitions[dsf.findSet(5)]); +} + +/* ************************************************************************* */ +TEST(DSF, partition2) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + dsf = dsf.makeUnion(5,6); + + list keys; keys += 7; + map > partitions = dsf.partition(keys); + LONGS_EQUAL(1, partitions.size()); + + set expected; expected += 7; + CHECK(expected == partitions[dsf.findSet(7)]); +} + +/* ************************************************************************* */ +TEST(DSF, partition3) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + dsf = dsf.makeUnion(5,6); + + list keys; keys += 5, 7; + map > partitions = dsf.partition(keys); + LONGS_EQUAL(2, partitions.size()); + + set expected; expected += 5; + CHECK(expected == partitions[dsf.findSet(5)]); +} + +/* ************************************************************************* */ +int func(const int& a) { return a + 10; } +TEST(DSF, map) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + dsf = dsf.makeUnion(5,6); + + DSFInt actual = dsf.map(&func); + DSFInt expected; + dsf = dsf.makeSet(15); + dsf = dsf.makeSet(16); + dsf = dsf.makeSet(17); + dsf = dsf.makeUnion(15,16); + CHECK(actual == expected); +} + +/* ************************************************************************* */ +TEST(DSF, flatten) { + DSFInt dsf; + dsf = dsf.makePair(1, 2); + dsf = dsf.makePair(2, 3); + dsf = dsf.makePair(5, 6); + dsf = dsf.makePair(6, 7); + dsf = dsf.makeUnion(2, 6); + + DSFInt actual = dsf.flatten(); + DSFInt expected; + expected = expected.makePair(1, 2); + expected = expected.makePair(1, 3); + expected = expected.makePair(1, 5); + expected = expected.makePair(1, 6); + expected = expected.makePair(1, 7); + CHECK(actual == expected); +} +/* ************************************************************************* */ +int main() { TestResult tr; return TestRegistry::runAllTests(tr);} +/* ************************************************************************* */ +