diff --git a/cpp/FactorGraph.h b/cpp/FactorGraph.h index 7cf78fd68..09ae2f013 100644 --- a/cpp/FactorGraph.h +++ b/cpp/FactorGraph.h @@ -65,7 +65,11 @@ namespace gtsam { /** return the numbers of the factors_ in the factor graph */ inline size_t size() const { - return factors_.size(); + int size_=0; + for (const_iterator factor = factors_.begin(); factor != factors_.end(); factor++) + if(*factor != NULL) + size_++; + return size_; } /** Add a factor */ @@ -91,7 +95,7 @@ namespace gtsam { std::cout << s << std::endl; printf("size: %d\n", (int) size()); for (const_iterator factor = factors_.begin(); factor != factors_.end(); factor++) - (*factor)->print(); + if(*factor != NULL) (*factor)->print(); } /** Check equality */ diff --git a/cpp/LinearFactorGraph.cpp b/cpp/LinearFactorGraph.cpp index e5d471748..4d4637ba7 100644 --- a/cpp/LinearFactorGraph.cpp +++ b/cpp/LinearFactorGraph.cpp @@ -57,21 +57,22 @@ list LinearFactorGraph::factors(const string& key) const { /* ************************************************************************* */ -/** O(n) */ +/** find all non-NULL factors for a variable, then set factors to NULL */ /* ************************************************************************* */ -LinearFactorSet -LinearFactorGraph::find_factors_and_remove(const string& key) -{ +LinearFactorSet LinearFactorGraph::find_factors_and_remove(const string& key) { LinearFactorSet found; - for(iterator factor=factors_.begin(); factor!=factors_.end(); ) - if ((*factor)->involves(key)) { - found.push_back(*factor); - factor = factors_.erase(factor); - } else { - factor++; // important, erase will have effect of ++ - } + Indices::iterator it = indices_.find(key); + list *indices_ptr; // pointer to indices list in indices_ map + indices_ptr = &(it->second); + for (list::iterator it = indices_ptr->begin(); it != indices_ptr->end(); it++) { + if(factors_[*it] == NULL){ // skip NULL factors + continue; + } + found.push_back(factors_[*it]); + factors_[*it].reset(); // set factor to NULL. + } return found; } diff --git a/cpp/testLinearFactorGraph.cpp b/cpp/testLinearFactorGraph.cpp index 3eef30000..e29643c0a 100644 --- a/cpp/testLinearFactorGraph.cpp +++ b/cpp/testLinearFactorGraph.cpp @@ -494,6 +494,32 @@ TEST( LinearFactorGraph, find_factors_and_remove ) LONGS_EQUAL(1,fg.size()); } +/* ************************************************************************* */ +TEST( LinearFactorGraph, find_factors_and_remove__twice ) +{ + // create the graph + LinearFactorGraph fg = createLinearFactorGraph(); + + // We expect to remove these three factors: 0, 1, 2 + LinearFactor::shared_ptr f0 = fg[0]; + LinearFactor::shared_ptr f1 = fg[1]; + LinearFactor::shared_ptr f2 = fg[2]; + + // call the function + LinearFactorSet factors = fg.find_factors_and_remove("x1"); + + // Check the factors + CHECK(f0==factors[0]); + CHECK(f1==factors[1]); + CHECK(f2==factors[2]); + + factors = fg.find_factors_and_remove("x1"); + CHECK(factors.size() == 0); + + // CHECK if the factors are deleted from the factor graph + LONGS_EQUAL(1,fg.size()); + } + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */