add split to FactorGraph
parent
5dc237eeea
commit
06b7f8ee04
|
@ -289,7 +289,7 @@ map<string, string> FactorGraph<Factor>::findMinimumSpanningTree() const {
|
||||||
Vertex v1, v2;
|
Vertex v1, v2;
|
||||||
BOOST_FOREACH(sharedFactor factor, factors_){
|
BOOST_FOREACH(sharedFactor factor, factors_){
|
||||||
if (factor->keys().size() > 2)
|
if (factor->keys().size() > 2)
|
||||||
throw(invalid_argument("findMinimumSpanningTree: only support factors with two keys"));
|
throw(invalid_argument("findMinimumSpanningTree: only support factors with at most two keys"));
|
||||||
|
|
||||||
if (factor->keys().size() == 1)
|
if (factor->keys().size() == 1)
|
||||||
continue;
|
continue;
|
||||||
|
@ -330,6 +330,31 @@ map<string, string> FactorGraph<Factor>::findMinimumSpanningTree() const {
|
||||||
return tree;
|
return tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<class Factor>
|
||||||
|
pair<FactorGraph<Factor>, FactorGraph<Factor> > FactorGraph<Factor>::split(map<string, string> tree) const {
|
||||||
|
|
||||||
|
FactorGraph<Factor> Ab1, Ab2;
|
||||||
|
BOOST_FOREACH(sharedFactor factor, factors_){
|
||||||
|
if (factor->keys().size() > 2)
|
||||||
|
throw(invalid_argument("split: only support factors with at most two keys"));
|
||||||
|
|
||||||
|
if (factor->keys().size() == 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
string key1 = factor->keys().front();
|
||||||
|
string key2 = factor->keys().back();
|
||||||
|
// if the tree contains the key
|
||||||
|
if (tree.find(key1) != tree.end() && tree[key1].compare(key2) == 0 ||
|
||||||
|
tree.find(key2) != tree.end() && tree[key2].compare(key1) == 0)
|
||||||
|
Ab1.push_back(factor);
|
||||||
|
else
|
||||||
|
Ab2.push_back(factor);
|
||||||
|
}
|
||||||
|
|
||||||
|
return make_pair(Ab1, Ab2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/* find factors and remove them from the factor graph: O(n) */
|
/* find factors and remove them from the factor graph: O(n) */
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -121,6 +121,12 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
std::map<std::string, std::string> findMinimumSpanningTree() const;
|
std::map<std::string, std::string> findMinimumSpanningTree() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Split the graph into two parts: one corresponds to the given spanning tre,
|
||||||
|
* and the other corresponds to the rest of the factors
|
||||||
|
*/
|
||||||
|
std::pair<FactorGraph<Factor>, FactorGraph<Factor> > split(std::map<std::string, std::string> tree) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** Associate factor index with the variables connected to the factor */
|
/** Associate factor index with the variables connected to the factor */
|
||||||
void associateFactor(int index, sharedFactor factor);
|
void associateFactor(int index, sharedFactor factor);
|
||||||
|
|
|
@ -750,6 +750,32 @@ TEST( GaussianFactorGraph, findMinimumSpanningTree )
|
||||||
CHECK(tree["x4"].compare("x1")==0);
|
CHECK(tree["x4"].compare("x1")==0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST( GaussianFactorGraph, split )
|
||||||
|
{
|
||||||
|
GaussianFactorGraph g;
|
||||||
|
Matrix I = eye(2);
|
||||||
|
Vector b = Vector_(0, 0, 0);
|
||||||
|
g.add("x1", I, "x2", I, b, 0);
|
||||||
|
g.add("x1", I, "x3", I, b, 0);
|
||||||
|
g.add("x1", I, "x4", I, b, 0);
|
||||||
|
g.add("x2", I, "x3", I, b, 0);
|
||||||
|
g.add("x2", I, "x4", I, b, 0);
|
||||||
|
|
||||||
|
map<string, string> tree;
|
||||||
|
tree["x1"] = "x1";
|
||||||
|
tree["x2"] = "x1";
|
||||||
|
tree["x3"] = "x1";
|
||||||
|
tree["x4"] = "x1";
|
||||||
|
|
||||||
|
GaussianFactorGraph Ab1, Ab2;
|
||||||
|
pair<FactorGraph<GaussianFactor>, FactorGraph<GaussianFactor> > gg = g.split(tree);
|
||||||
|
Ab1 = *reinterpret_cast<GaussianFactorGraph*>(&(gg.first));
|
||||||
|
Ab2 = *reinterpret_cast<GaussianFactorGraph*>(&(gg.second));
|
||||||
|
LONGS_EQUAL(3, Ab1.size());
|
||||||
|
LONGS_EQUAL(2, Ab2.size());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue