added a vector-based DSF implmentation

release/4.3a0
Kai Ni 2010-06-25 06:35:44 +00:00
parent ef92f1b365
commit 2d40df17ac
7 changed files with 230 additions and 11 deletions

View File

@ -317,6 +317,7 @@
</target> </target>
<target name="clean" path="wrap" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="clean" path="wrap" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>clean</buildTarget> <buildTarget>clean</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>true</useDefaultCommand> <useDefaultCommand>true</useDefaultCommand>
@ -476,7 +477,6 @@
</target> </target>
<target name="testBayesTree.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testBayesTree.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testBayesTree.run</buildTarget> <buildTarget>testBayesTree.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>false</useDefaultCommand> <useDefaultCommand>false</useDefaultCommand>
@ -484,6 +484,7 @@
</target> </target>
<target name="testSymbolicBayesNet.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testSymbolicBayesNet.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testSymbolicBayesNet.run</buildTarget> <buildTarget>testSymbolicBayesNet.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>false</useDefaultCommand> <useDefaultCommand>false</useDefaultCommand>
@ -491,7 +492,6 @@
</target> </target>
<target name="testSymbolicFactorGraph.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testSymbolicFactorGraph.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testSymbolicFactorGraph.run</buildTarget> <buildTarget>testSymbolicFactorGraph.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>false</useDefaultCommand> <useDefaultCommand>false</useDefaultCommand>
@ -683,7 +683,6 @@
</target> </target>
<target name="testGraph.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testGraph.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testGraph.run</buildTarget> <buildTarget>testGraph.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>false</useDefaultCommand> <useDefaultCommand>false</useDefaultCommand>
@ -739,6 +738,7 @@
</target> </target>
<target name="testSimulated2D.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testSimulated2D.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testSimulated2D.run</buildTarget> <buildTarget>testSimulated2D.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>false</useDefaultCommand> <useDefaultCommand>false</useDefaultCommand>
@ -786,7 +786,6 @@
</target> </target>
<target name="testErrors.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testErrors.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testErrors.run</buildTarget> <buildTarget>testErrors.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>false</useDefaultCommand> <useDefaultCommand>false</useDefaultCommand>
@ -794,7 +793,6 @@
</target> </target>
<target name="testDSF.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testDSF.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testDSF.run</buildTarget> <buildTarget>testDSF.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>true</useDefaultCommand> <useDefaultCommand>true</useDefaultCommand>
@ -810,7 +808,6 @@
</target> </target>
<target name="testConstraintOptimizer.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testConstraintOptimizer.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testConstraintOptimizer.run</buildTarget> <buildTarget>testConstraintOptimizer.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>true</useDefaultCommand> <useDefaultCommand>true</useDefaultCommand>
@ -818,6 +815,7 @@
</target> </target>
<target name="testBTree.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testBTree.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testBTree.run</buildTarget> <buildTarget>testBTree.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>true</useDefaultCommand> <useDefaultCommand>true</useDefaultCommand>
@ -825,12 +823,19 @@
</target> </target>
<target name="testSimulated2DOriented.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="testSimulated2DOriented.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testSimulated2DOriented.run</buildTarget> <buildTarget>testSimulated2DOriented.run</buildTarget>
<stopOnError>true</stopOnError> <stopOnError>true</stopOnError>
<useDefaultCommand>false</useDefaultCommand> <useDefaultCommand>false</useDefaultCommand>
<runAllBuilders>true</runAllBuilders> <runAllBuilders>true</runAllBuilders>
</target> </target>
<target name="testDSFVector.run" path="cpp" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand>
<buildArguments/>
<buildTarget>testDSFVector.run</buildTarget>
<stopOnError>true</stopOnError>
<useDefaultCommand>true</useDefaultCommand>
<runAllBuilders>true</runAllBuilders>
</target>
<target name="install" path="" targetID="org.eclipse.cdt.build.MakeTargetBuilder"> <target name="install" path="" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
<buildCommand>make</buildCommand> <buildCommand>make</buildCommand>
<buildArguments>-j2</buildArguments> <buildArguments>-j2</buildArguments>

View File

@ -110,7 +110,7 @@ namespace gtsam {
} }
// get the nodes in the tree with the given label // get the nodes in the tree with the given label
Set set(const Label& label) { Set set(const Label& label) const {
Set set; Set set;
BOOST_FOREACH(const KeyLabel& pair, (Tree)*this) { BOOST_FOREACH(const KeyLabel& pair, (Tree)*this) {
if (pair.second == label || findSet(pair.second) == label) if (pair.second == label || findSet(pair.second) == label)

58
cpp/DSFVector.cpp Normal file
View File

@ -0,0 +1,58 @@
/*
* DSFVector.cpp
*
* Created on: Jun 25, 2010
* Author: nikai
* Description: a faster implementation for DSF, which uses vector rather than btree.
* As a result, the size of the forest is prefixed.
*/
#include "DSFVector.h"
using namespace std;
namespace gtsam {
/* ************************************************************************* */
DSFVector::DSFVector (const size_t numNodes) {
resize(numNodes);
int index = 0;
for(iterator it = begin(); it!=end(); it++, index++)
*it = index;
}
/* ************************************************************************* */
size_t DSFVector::findSet(const size_t& key) const {
size_t parent = at(key);
return parent == key ? key : findSet(parent);
}
/* ************************************************************************* */
std::set<size_t> DSFVector::set(const std::size_t& label) const {
std::set<size_t> set;
size_t key = 0;
std::vector<std::size_t>::const_iterator it = begin();
for (; it != end(); it++, key++) {
if (*it == label || findSet(*it) == label)
set.insert(key);
}
return set;
}
/* ************************************************************************* */
void DSFVector::makeUnionInPlace(const std::size_t& i1, const std::size_t& i2) {
at(findSet(i2)) = findSet(i1);
}
/* ************************************************************************* */
std::map<size_t, std::set<size_t> > DSFVector::sets() const {
std::map<size_t, std::set<size_t> > sets;
size_t key = 0;
std::vector<std::size_t>::const_iterator it = begin();
for (; it != end(); it++, key++) {
sets[findSet(*it)].insert(key);
}
return sets;
}
}

41
cpp/DSFVector.h Normal file
View File

@ -0,0 +1,41 @@
/*
* DSFVector.h
*
* Created on: Jun 25, 2010
* Author: nikai
* Description: a faster implementation for DSF, which uses vector rather than btree.
* As a result, the size of the forest is prefixed.
*/
#pragma once
#include <vector>
#include <map>
#include <set>
namespace gtsam {
/**
* A fast impelementation of disjoint set forests
*/
class DSFVector : protected std::vector<std::size_t> {
private:
public:
// constructor
DSFVector(const std::size_t numNodes);
// find the label of the set in which {key} lives
size_t findSet(const size_t& key) const;
// the in-place version of makeUnion
void makeUnionInPlace(const std::size_t& i1, const std::size_t& i2);
// get the nodes in the tree with the given label
std::set<size_t> set(const std::size_t& label) const;
// return all sets, i.e. a partition of all elements
std::map<size_t, std::set<size_t> > sets() const;
};
}

View File

@ -40,15 +40,17 @@ testMatrix_LDADD = libgtsam.la
# The header files will be installed in ~/include/gtsam # The header files will be installed in ~/include/gtsam
headers = gtsam.h Value.h Testable.h Factor.h Conditional.h headers = gtsam.h Value.h Testable.h Factor.h Conditional.h
headers += Ordering.h IndexTable.h numericalDerivative.h headers += Ordering.h IndexTable.h numericalDerivative.h
headers += BTree.h DSF.h headers += BTree.h DSF.h DSFVector.h
sources += Ordering.cpp smallExample.cpp sources += Ordering.cpp smallExample.cpp DSFVector.cpp
check_PROGRAMS += testOrdering testBTree testDSF check_PROGRAMS += testOrdering testBTree testDSF testDSFVector
testOrdering_SOURCES = testOrdering.cpp testOrdering_SOURCES = testOrdering.cpp
testOrdering_LDADD = libgtsam.la testOrdering_LDADD = libgtsam.la
testBTree_SOURCES = testBTree.cpp testBTree_SOURCES = testBTree.cpp
testBTree_LDADD = libgtsam.la testBTree_LDADD = libgtsam.la
testDSF_SOURCES = testDSF.cpp testDSF_SOURCES = testDSF.cpp
testDSF_LDADD = libgtsam.la testDSF_LDADD = libgtsam.la
testDSFVector_SOURCES = testDSFVector.cpp
testDSFVector_LDADD = libgtsam.la
# Symbolic Inference # Symbolic Inference
headers += SymbolicConditional.h headers += SymbolicConditional.h

View File

@ -99,6 +99,10 @@ namespace gtsam {
return keys_.empty(); return keys_.empty();
} }
/** get the size of the factor */
std::size_t size() const {
return keys_.size();
}
}; };

109
cpp/testDSFVector.cpp Normal file
View File

@ -0,0 +1,109 @@
/*
* testDSF.cpp
*
* Created on: June 25, 2010
* Author: nikai
* Description: unit tests for DSF
*/
#include <iostream>
#include <boost/assign/std/list.hpp>
#include <boost/assign/std/set.hpp>
using namespace boost::assign;
#include <CppUnitLite/TestHarness.h>
#include "DSFVector.h"
#include "Key.h"
using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST(DSFVectorVector, findSet) {
DSFVector dsf(3);
CHECK(dsf.findSet(0) != dsf.findSet(2));
}
/* ************************************************************************* */
TEST(DSFVectorVector, makeUnionInPlace) {
DSFVector dsf(3);
dsf.makeUnionInPlace(0,2);
CHECK(dsf.findSet(0) == dsf.findSet(2));
}
/* ************************************************************************* */
TEST(DSFVector, makeUnion2) {
DSFVector dsf(3);
dsf.makeUnionInPlace(2,0);
CHECK(dsf.findSet(0) == dsf.findSet(2));
}
/* ************************************************************************* */
TEST(DSFVector, makeUnion3) {
DSFVector dsf(3);
dsf.makeUnionInPlace(0,1);
dsf.makeUnionInPlace(1,2);
CHECK(dsf.findSet(0) == dsf.findSet(2));
}
/* ************************************************************************* */
TEST(DSFVector, sets) {
DSFVector dsf(2);
dsf.makeUnionInPlace(0,1);
map<size_t, set<size_t> > sets = dsf.sets();
LONGS_EQUAL(1, sets.size());
set<size_t> expected; expected += 0, 1;
CHECK(expected == sets[dsf.findSet(0)]);
}
/* ************************************************************************* */
TEST(DSFVector, sets2) {
DSFVector dsf(3);
dsf.makeUnionInPlace(0,1);
dsf.makeUnionInPlace(1,2);
map<size_t, set<size_t> > sets = dsf.sets();
LONGS_EQUAL(1, sets.size());
set<size_t> expected; expected += 0, 1, 2;
CHECK(expected == sets[dsf.findSet(0)]);
}
/* ************************************************************************* */
TEST(DSFVector, sets3) {
DSFVector dsf(3);
dsf.makeUnionInPlace(0,1);
map<size_t, set<size_t> > sets = dsf.sets();
LONGS_EQUAL(2, sets.size());
set<size_t> expected; expected += 0, 1;
CHECK(expected == sets[dsf.findSet(0)]);
}
/* ************************************************************************* */
TEST(DSFVector, set) {
DSFVector dsf(3);
dsf.makeUnionInPlace(0,1);
set<size_t> set = dsf.set(0);
LONGS_EQUAL(2, set.size());
std::set<size_t> expected; expected += 0, 1;
CHECK(expected == set);
}
/* ************************************************************************* */
TEST(DSFVector, set2) {
DSFVector dsf(3);
dsf.makeUnionInPlace(0,1);
dsf.makeUnionInPlace(1,2);
set<size_t> set = dsf.set(0);
LONGS_EQUAL(3, set.size());
std::set<size_t> expected; expected += 0, 1, 2;
CHECK(expected == set);
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
/* ************************************************************************* */