diff --git a/cpp/FactorGraph-inl.h b/cpp/FactorGraph-inl.h index d98087c7f..af20cc8e6 100644 --- a/cpp/FactorGraph-inl.h +++ b/cpp/FactorGraph-inl.h @@ -289,7 +289,7 @@ map FactorGraph::findMinimumSpanningTree() const { Vertex v1, v2; BOOST_FOREACH(sharedFactor factor, factors_){ if (factor->keys().size() > 2) - throw(invalid_argument("findMinimumSpanningTree: only support factors with two keys")); + throw(invalid_argument("findMinimumSpanningTree: only support factors with at most two keys")); if (factor->keys().size() == 1) continue; @@ -330,6 +330,31 @@ map FactorGraph::findMinimumSpanningTree() const { return tree; } +template +pair, FactorGraph > FactorGraph::split(map tree) const { + + FactorGraph Ab1, Ab2; + BOOST_FOREACH(sharedFactor factor, factors_){ + if (factor->keys().size() > 2) + throw(invalid_argument("split: only support factors with at most two keys")); + + if (factor->keys().size() == 1) + continue; + + string key1 = factor->keys().front(); + string key2 = factor->keys().back(); + // if the tree contains the key + if (tree.find(key1) != tree.end() && tree[key1].compare(key2) == 0 || + tree.find(key2) != tree.end() && tree[key2].compare(key1) == 0) + Ab1.push_back(factor); + else + Ab2.push_back(factor); + } + + return make_pair(Ab1, Ab2); +} + + /* ************************************************************************* */ /* find factors and remove them from the factor graph: O(n) */ /* ************************************************************************* */ diff --git a/cpp/FactorGraph.h b/cpp/FactorGraph.h index a93e41e8c..518c33c82 100644 --- a/cpp/FactorGraph.h +++ b/cpp/FactorGraph.h @@ -121,6 +121,12 @@ namespace gtsam { */ std::map findMinimumSpanningTree() const; + /** + * Split the graph into two parts: one corresponds to the given spanning tre, + * and the other corresponds to the rest of the factors + */ + std::pair, FactorGraph > split(std::map tree) const; + private: /** Associate factor index with the variables connected to the factor */ void associateFactor(int index, sharedFactor factor); diff --git a/cpp/testGaussianFactorGraph.cpp b/cpp/testGaussianFactorGraph.cpp index 5f31de35f..60164ec80 100644 --- a/cpp/testGaussianFactorGraph.cpp +++ b/cpp/testGaussianFactorGraph.cpp @@ -750,6 +750,32 @@ TEST( GaussianFactorGraph, findMinimumSpanningTree ) CHECK(tree["x4"].compare("x1")==0); } +/* ************************************************************************* */ +TEST( GaussianFactorGraph, split ) +{ + GaussianFactorGraph g; + Matrix I = eye(2); + Vector b = Vector_(0, 0, 0); + g.add("x1", I, "x2", I, b, 0); + g.add("x1", I, "x3", I, b, 0); + g.add("x1", I, "x4", I, b, 0); + g.add("x2", I, "x3", I, b, 0); + g.add("x2", I, "x4", I, b, 0); + + map tree; + tree["x1"] = "x1"; + tree["x2"] = "x1"; + tree["x3"] = "x1"; + tree["x4"] = "x1"; + + GaussianFactorGraph Ab1, Ab2; + pair, FactorGraph > gg = g.split(tree); + Ab1 = *reinterpret_cast(&(gg.first)); + Ab2 = *reinterpret_cast(&(gg.second)); + LONGS_EQUAL(3, Ab1.size()); + LONGS_EQUAL(2, Ab2.size()); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */