From dfb1e21284b7e9daac81308c3b1c0755ebac6006 Mon Sep 17 00:00:00 2001 From: Richard Roberts Date: Tue, 3 Jan 2012 17:49:29 +0000 Subject: [PATCH] Function to remove factors from a VariableIndex --- gtsam/inference/VariableIndex.cpp | 4 +- gtsam/inference/VariableIndex.h | 53 ++++++++++++++++----- gtsam/inference/tests/testVariableIndex.cpp | 32 +++++++++++++ 3 files changed, 75 insertions(+), 14 deletions(-) diff --git a/gtsam/inference/VariableIndex.cpp b/gtsam/inference/VariableIndex.cpp index 88e037434..b44fa7f9a 100644 --- a/gtsam/inference/VariableIndex.cpp +++ b/gtsam/inference/VariableIndex.cpp @@ -16,6 +16,7 @@ */ #include + #include namespace gtsam { @@ -49,7 +50,8 @@ bool VariableIndex::equals(const VariableIndex& other, double tol) const { /* ************************************************************************* */ void VariableIndex::print(const std::string& str) const { - std::cout << str; + std::cout << str << "\n"; + std::cout << "nEntries = " << this->nEntries_ << ", nFactors = " << this->nFactors_ << "\n"; Index var = 0; BOOST_FOREACH(const Factors& variable, index_.container()) { Permutation::const_iterator rvar = find(index_.permutation().begin(), index_.permutation().end(), var); diff --git a/gtsam/inference/VariableIndex.h b/gtsam/inference/VariableIndex.h index 0c9461bfe..a7d6f9ba8 100644 --- a/gtsam/inference/VariableIndex.h +++ b/gtsam/inference/VariableIndex.h @@ -17,12 +17,13 @@ #pragma once +#include +#include +#include + #include #include -#include -#include - namespace gtsam { class Inference; @@ -94,7 +95,15 @@ public: * Augment the variable index with new factors. This can be used when * solving problems incrementally. */ - template void augment(const FactorGraph& factorGraph); + template void augment(const FactorGraph& factors); + + /** + * Remove entries corresponding to the specified factors. + * @param indices The indices of the factors to remove, which must match \c factors + * @param factors The factors being removed, which must symbolically correspond + * exactly to the factors with the specified \c indices that were added. + */ + template void remove(const CONTAINER& indices, const FactorGraph& factors); /** Test for equality (for unit tests and debug assertions). */ bool equals(const VariableIndex& other, double tol=0.0) const; @@ -119,7 +128,7 @@ template void VariableIndex::fill(const FactorGraph& factorGraph) { // Build index mapping from variable id to factor index - for(size_t fi=0; fikeys()) { if(key < index_.size()) { @@ -127,8 +136,9 @@ void VariableIndex::fill(const FactorGraph& factorGraph) { ++ nEntries_; } } - ++ nFactors_; } + ++ nFactors_; // Increment factor count even if factors are null, to keep indices consistent + } } /* ************************************************************************* */ @@ -167,13 +177,13 @@ VariableIndex::VariableIndex(const FactorGraph& factorGraph, Index nVariables) : /* ************************************************************************* */ template -void VariableIndex::augment(const FactorGraph& factorGraph) { +void VariableIndex::augment(const FactorGraph& factors) { // If the factor graph is empty, return an empty index because inside this // if block we assume at least one factor. - if(factorGraph.size() > 0) { + if(factors.size() > 0) { // Find highest-numbered variable Index maxVar = 0; - BOOST_FOREACH(const typename FactorGraph::sharedFactor& factor, factorGraph) { + BOOST_FOREACH(const typename FactorGraph::sharedFactor& factor, factors) { if(factor) { BOOST_FOREACH(const Index key, factor->keys()) { if(key > maxVar) @@ -191,15 +201,32 @@ void VariableIndex::augment(const FactorGraph& factorGraph) { // Augment index mapping from variable id to factor index size_t orignFactors = nFactors_; - for(size_t fi=0; fikeys()) { + for(size_t fi=0; fikeys()) { index_[key].push_back(orignFactors + fi); ++ nEntries_; } - ++ nFactors_; } + ++ nFactors_; // Increment factor count even if factors are null, to keep indices consistent + } } } +/* ************************************************************************* */ +template +void VariableIndex::remove(const CONTAINER& indices, const FactorGraph& factors) { + for(size_t fi=0; fikeys().size(); ++ji) { + Factors& factorEntries = index_[factors[fi]->keys()[ji]]; + Factors::iterator entry = std::find(factorEntries.begin(), factorEntries.end(), indices[fi]); + if(entry == factorEntries.end()) + throw std::invalid_argument("Internal error, indices and factors passed into VariableIndex::remove are not consistent with the existing variable index"); + factorEntries.erase(entry); + -- nEntries_; + } + } +} + } diff --git a/gtsam/inference/tests/testVariableIndex.cpp b/gtsam/inference/tests/testVariableIndex.cpp index 8e52e2176..2ffc03a97 100644 --- a/gtsam/inference/tests/testVariableIndex.cpp +++ b/gtsam/inference/tests/testVariableIndex.cpp @@ -46,6 +46,38 @@ TEST(VariableIndex, augment) { CHECK(assert_equal(expected, actual)); } +/* ************************************************************************* */ +TEST(VariableIndex, remove) { + + SymbolicFactorGraph fg1, fg2; + fg1.push_factor(0, 1); + fg1.push_factor(0, 2); + fg1.push_factor(5, 9); + fg1.push_factor(2, 3); + fg2.push_factor(1, 3); + fg2.push_factor(2, 4); + fg2.push_factor(3, 5); + fg2.push_factor(5, 6); + + SymbolicFactorGraph fgCombined; fgCombined.push_back(fg1); fgCombined.push_back(fg2); + + // Create a factor graph containing only the factors from fg2 and with null + // factors in the place of those of fg1, so that the factor indices are correct. + SymbolicFactorGraph fg2removed(fgCombined); + fg2removed.remove(0); fg2removed.remove(1); fg2removed.remove(2); fg2removed.remove(3); + + // The expected VariableIndex has the same factor indices as fgCombined but + // with entries from fg1 removed, and still has all 10 variables. + VariableIndex expected(fg2removed, 10); + VariableIndex actual(fgCombined); + vector indices; + indices.push_back(0); indices.push_back(1); indices.push_back(2); indices.push_back(3); + actual.remove(indices, fg1); + + CHECK(assert_equal(expected, actual)); + +} + /* ************************************************************************* */ int main() { TestResult tr;