diff --git a/cpp/FactorGraph-inl.h b/cpp/FactorGraph-inl.h index 49f7c78da..22a8b7a0c 100644 --- a/cpp/FactorGraph-inl.h +++ b/cpp/FactorGraph-inl.h @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include "Ordering.h" #include "FactorGraph.h" @@ -276,13 +276,14 @@ void FactorGraph::associateFactor(int index, sharedFactor factor) { /* ************************************************************************* */ template -vector > FactorGraph::findMinimumSpanningTree() const { +map FactorGraph::findMinimumSpanningTree() const { typedef boost::adjacency_list< boost::vecS, boost::vecS, boost::undirectedS, boost::property, boost::property > Graph; typedef boost::graph_traits::vertex_descriptor Vertex; + typedef boost::graph_traits::vertex_iterator VertexIterator; typedef boost::graph_traits::edge_descriptor Edge; // convert the factor graph to boost graph @@ -316,16 +317,17 @@ vector > FactorGraph::findMinimumSpanningTree() con } // find minimum spanning tree - vector spanning_tree; - boost::kruskal_minimum_spanning_tree(g, back_inserter(spanning_tree)); + vector p_map(boost::num_vertices(g)); + prim_minimum_spanning_tree(g, &p_map[0]); - // convert edge to skin - vector > tree; - for (vector::iterator ei = spanning_tree.begin(); ei - != spanning_tree.end(); ++ei) { - string key1 = boost::get(boost::vertex_name, g, boost::source(*ei,g)); - string key2 = boost::get(boost::vertex_name, g, boost::target(*ei,g)); - tree.push_back(make_pair(key1, key2)); + // convert edge to string pairs + map tree; + VertexIterator itVertex = boost::vertices(g).first; + for (vector::iterator vi = p_map.begin(); vi!=p_map.end(); itVertex++, vi++) { + string key = boost::get(boost::vertex_name, g, *itVertex); + string parent = boost::get(boost::vertex_name, g, *vi); + // printf("%s parent: %s\n", key.c_str(), parent.c_str()); + tree.insert(make_pair(key, parent)); } return tree; diff --git a/cpp/FactorGraph.h b/cpp/FactorGraph.h index 223ca9711..a93e41e8c 100644 --- a/cpp/FactorGraph.h +++ b/cpp/FactorGraph.h @@ -117,9 +117,9 @@ namespace gtsam { std::vector findAndRemoveFactors(const std::string& key); /** - * find the minimum spanning tree + * find the minimum spanning tree. */ - std::vector > findMinimumSpanningTree() const; + std::map findMinimumSpanningTree() const; private: /** Associate factor index with the variables connected to the factor */ diff --git a/cpp/testGaussianFactorGraph.cpp b/cpp/testGaussianFactorGraph.cpp index 7825682ee..5f31de35f 100644 --- a/cpp/testGaussianFactorGraph.cpp +++ b/cpp/testGaussianFactorGraph.cpp @@ -743,8 +743,11 @@ TEST( GaussianFactorGraph, findMinimumSpanningTree ) g.add("x2", I, "x4", I, b, 0); g.add("x3", I, "x4", I, b, 0); - vector > tree = g.findMinimumSpanningTree(); - LONGS_EQUAL(3,tree.size()); + map tree = g.findMinimumSpanningTree(); + CHECK(tree["x1"].compare("x1")==0); + CHECK(tree["x2"].compare("x1")==0); + CHECK(tree["x3"].compare("x1")==0); + CHECK(tree["x4"].compare("x1")==0); } /* ************************************************************************* */