diff --git a/cpp/FactorGraph-inl.h b/cpp/FactorGraph-inl.h index 724296eac..34aa82754 100644 --- a/cpp/FactorGraph-inl.h +++ b/cpp/FactorGraph-inl.h @@ -64,8 +64,8 @@ size_t FactorGraph::nrFactors() const { /* ************************************************************************* */ template void FactorGraph::push_back(shared_factor factor) { - factors_.push_back(factor); // add the actual factor + if (factor==NULL) return; int i = factors_.size() - 1; // index of factor list keys = factor->keys(); // get keys for factor @@ -157,6 +157,15 @@ Ordering FactorGraph::getOrdering() const { return colamd(n_col, n_row, nrNonZeros, columns); } +/* ************************************************************************* */ +/** O(1) */ +/* ************************************************************************* */ +template +list FactorGraph::factors(const string& key) const { + Indices::const_iterator it = indices_.find(key); + return it->second; +} + /* ************************************************************************* */ /** find all non-NULL factors for a variable, then set factors to NULL */ /* ************************************************************************* */ diff --git a/cpp/FactorGraph.h b/cpp/FactorGraph.h index 77b94e6b3..e39fe1448 100644 --- a/cpp/FactorGraph.h +++ b/cpp/FactorGraph.h @@ -82,6 +82,12 @@ namespace gtsam { */ Ordering getOrdering() const; + /** + * Return indices for all factors that involve the given node + * @param key the key for the given node + */ + std::list factors(const std::string& key) const; + /** * find all the factors that involve the given node and remove them * from the factor graph diff --git a/cpp/LinearFactorGraph.cpp b/cpp/LinearFactorGraph.cpp index 63052716b..2e4d9147b 100644 --- a/cpp/LinearFactorGraph.cpp +++ b/cpp/LinearFactorGraph.cpp @@ -47,14 +47,6 @@ set LinearFactorGraph::find_separator(const string& key) const return separator; } -/* ************************************************************************* */ -/** O(1) */ -/* ************************************************************************* */ -list LinearFactorGraph::factors(const string& key) const { - Indices::const_iterator it = indices_.find(key); - return it->second; -} - /* ************************************************************************* */ /* find factors and remove them from the factor graph: O(n) */ /* ************************************************************************* */ diff --git a/cpp/LinearFactorGraph.h b/cpp/LinearFactorGraph.h index 1ce89df55..e3e7668c2 100644 --- a/cpp/LinearFactorGraph.h +++ b/cpp/LinearFactorGraph.h @@ -67,12 +67,6 @@ namespace gtsam { */ std::set find_separator(const std::string& key) const; - /** - * Return indices for all factors that involve the given node - * @param key the key for the given node - */ - std::list factors(const std::string& key) const; - /** * extract and combine all the factors that involve a given node * NOTE: the combined factor will be depends on a system-dependent diff --git a/cpp/testSymbolicBayesChain.cpp b/cpp/testSymbolicBayesChain.cpp index 8527e319d..334655176 100644 --- a/cpp/testSymbolicBayesChain.cpp +++ b/cpp/testSymbolicBayesChain.cpp @@ -72,7 +72,7 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST( SymbolicBayesChain, symbolicFactorGraph ) +TEST( SymbolicFactorGraph, symbolicFactorGraph ) { // construct expected symbolic graph SymbolicFactorGraph expected; @@ -98,8 +98,46 @@ TEST( SymbolicBayesChain, symbolicFactorGraph ) SymbolicFactorGraph actual(factorGraph); CHECK(assert_equal(expected, actual)); +} - //symbolicGraph.find_factors_and_remove("x"); +/* ************************************************************************* */ +TEST( SymbolicFactorGraph, find_factors_and_remove ) +{ + // construct it from the factor graph graph + LinearFactorGraph factorGraph = createLinearFactorGraph(); + SymbolicFactorGraph actual(factorGraph); + SymbolicFactor::shared_ptr f1 = actual[0]; + SymbolicFactor::shared_ptr f3 = actual[2]; + actual.find_factors_and_remove("x2"); + + // construct expected graph after find_factors_and_remove + SymbolicFactorGraph expected; + SymbolicFactor::shared_ptr null; + expected.push_back(f1); + expected.push_back(null); + expected.push_back(f3); + expected.push_back(null); + + CHECK(assert_equal(expected, actual)); +} +/* ************************************************************************* */ +TEST( SymbolicFactorGraph, factor_lookup) +{ + // create a test graph + LinearFactorGraph factorGraph = createLinearFactorGraph(); + SymbolicFactorGraph fg(factorGraph); + + // ask for all factor indices connected to x1 + list x1_factors = fg.factors("x1"); + int x1_indices[] = { 0, 1, 2 }; + list x1_expected(x1_indices, x1_indices + 3); + CHECK(x1_factors==x1_expected); + + // ask for all factor indices connected to x2 + list x2_factors = fg.factors("x2"); + int x2_indices[] = { 1, 3 }; + list x2_expected(x2_indices, x2_indices + 2); + CHECK(x2_factors==x2_expected); } /* ************************************************************************* */