add findMinimumSpanningTree to FactorGraph
parent
cd644e75a5
commit
9845a5ae37
|
@ -18,10 +18,16 @@
|
||||||
#include <boost/foreach.hpp>
|
#include <boost/foreach.hpp>
|
||||||
#include <boost/tuple/tuple.hpp>
|
#include <boost/tuple/tuple.hpp>
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
|
#include <boost/graph/graph_traits.hpp>
|
||||||
|
#include <boost/graph/adjacency_list.hpp>
|
||||||
|
#include <boost/graph/kruskal_min_spanning_tree.hpp>
|
||||||
#include <colamd/colamd.h>
|
#include <colamd/colamd.h>
|
||||||
#include "Ordering.h"
|
#include "Ordering.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// trick from some reading group
|
// trick from some reading group
|
||||||
#define FOREACH_PAIR( KEY, VAL, COL) BOOST_FOREACH (boost::tie(KEY,VAL),COL)
|
#define FOREACH_PAIR( KEY, VAL, COL) BOOST_FOREACH (boost::tie(KEY,VAL),COL)
|
||||||
|
|
||||||
|
@ -268,6 +274,63 @@ void FactorGraph<Factor>::associateFactor(int index, sharedFactor factor) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template<class Factor>
|
||||||
|
vector<pair<string, string> > FactorGraph<Factor>::findMinimumSpanningTree() const {
|
||||||
|
|
||||||
|
typedef boost::adjacency_list<
|
||||||
|
boost::vecS, boost::vecS, boost::undirectedS,
|
||||||
|
boost::property<boost::vertex_name_t, string>,
|
||||||
|
boost::property<boost::edge_weight_t, int> > Graph;
|
||||||
|
typedef boost::graph_traits<Graph>::vertex_descriptor Vertex;
|
||||||
|
typedef boost::graph_traits<Graph>::edge_descriptor Edge;
|
||||||
|
|
||||||
|
// convert the factor graph to boost graph
|
||||||
|
Graph g(0);
|
||||||
|
map<string, Vertex> key2vertex;
|
||||||
|
Vertex v1, v2;
|
||||||
|
BOOST_FOREACH(sharedFactor factor, factors_){
|
||||||
|
if (factor->keys().size() > 2)
|
||||||
|
throw(invalid_argument("findMinimumSpanningTree: only support factors with two keys"));
|
||||||
|
|
||||||
|
if (factor->keys().size() == 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
string key1 = factor->keys().front();
|
||||||
|
string key2 = factor->keys().back();
|
||||||
|
|
||||||
|
if (key2vertex.find(key1) == key2vertex.end()) {
|
||||||
|
v1 = add_vertex(key1, g);
|
||||||
|
key2vertex.insert(make_pair(key1, v1));
|
||||||
|
} else
|
||||||
|
v1 = key2vertex[key1];
|
||||||
|
|
||||||
|
if (key2vertex.find(key2) == key2vertex.end()) {
|
||||||
|
v2 = add_vertex(key2, g);
|
||||||
|
key2vertex.insert(make_pair(key2, v2));
|
||||||
|
} else
|
||||||
|
v2 = key2vertex[key2];
|
||||||
|
|
||||||
|
boost::property<boost::edge_weight_t, int> edge_property(1); // assume constant edge weight here
|
||||||
|
boost::add_edge(v1, v2, edge_property, g);
|
||||||
|
}
|
||||||
|
|
||||||
|
// find minimum spanning tree
|
||||||
|
vector<Edge> spanning_tree;
|
||||||
|
boost::kruskal_minimum_spanning_tree(g, back_inserter(spanning_tree));
|
||||||
|
|
||||||
|
// convert edge to skin
|
||||||
|
vector<pair<string, string> > tree;
|
||||||
|
for (vector<Edge>::iterator ei = spanning_tree.begin(); ei
|
||||||
|
!= spanning_tree.end(); ++ei) {
|
||||||
|
string key1 = boost::get(boost::vertex_name, g, boost::source(*ei,g));
|
||||||
|
string key2 = boost::get(boost::vertex_name, g, boost::target(*ei,g));
|
||||||
|
tree.push_back(make_pair(key1, key2));
|
||||||
|
}
|
||||||
|
|
||||||
|
return tree;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/* find factors and remove them from the factor graph: O(n) */
|
/* find factors and remove them from the factor graph: O(n) */
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -116,6 +116,11 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
std::vector<sharedFactor> findAndRemoveFactors(const std::string& key);
|
std::vector<sharedFactor> findAndRemoveFactors(const std::string& key);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* find the minimum spanning tree
|
||||||
|
*/
|
||||||
|
std::vector<std::pair<std::string, std::string> > findMinimumSpanningTree() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** Associate factor index with the variables connected to the factor */
|
/** Associate factor index with the variables connected to the factor */
|
||||||
void associateFactor(int index, sharedFactor factor);
|
void associateFactor(int index, sharedFactor factor);
|
||||||
|
|
|
@ -319,4 +319,3 @@ boost::shared_ptr<VectorConfig> GaussianFactorGraph::conjugateGradientDescent_(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
|
|
|
@ -218,6 +218,7 @@ namespace gtsam {
|
||||||
boost::shared_ptr<VectorConfig> conjugateGradientDescent_(
|
boost::shared_ptr<VectorConfig> conjugateGradientDescent_(
|
||||||
const VectorConfig& x0, bool verbose = false, double epsilon = 1e-3,
|
const VectorConfig& x0, bool verbose = false, double epsilon = 1e-3,
|
||||||
size_t maxIterations = 0) const;
|
size_t maxIterations = 0) const;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -730,6 +730,23 @@ TEST( GaussianFactorGraph, constrained_multi2 )
|
||||||
CHECK(assert_equal(expected, actual));
|
CHECK(assert_equal(expected, actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST( GaussianFactorGraph, findMinimumSpanningTree )
|
||||||
|
{
|
||||||
|
GaussianFactorGraph g;
|
||||||
|
Matrix I = eye(2);
|
||||||
|
Vector b = Vector_(0, 0, 0);
|
||||||
|
g.add("x1", I, "x2", I, b, 0);
|
||||||
|
g.add("x1", I, "x3", I, b, 0);
|
||||||
|
g.add("x1", I, "x4", I, b, 0);
|
||||||
|
g.add("x2", I, "x3", I, b, 0);
|
||||||
|
g.add("x2", I, "x4", I, b, 0);
|
||||||
|
g.add("x3", I, "x4", I, b, 0);
|
||||||
|
|
||||||
|
vector<pair<string, string> > tree = g.findMinimumSpanningTree();
|
||||||
|
LONGS_EQUAL(3,tree.size());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue