diff --git a/cpp/GaussianFactor.cpp b/cpp/GaussianFactor.cpp index e1cb81933..0cacada3c 100644 --- a/cpp/GaussianFactor.cpp +++ b/cpp/GaussianFactor.cpp @@ -220,35 +220,28 @@ Matrix GaussianFactor::matrix_augmented(const Ordering& ordering) const { /* ************************************************************************* */ boost::tuple, list, list > -GaussianFactor::sparse(const Ordering& ordering, const Dimensions& variables) const { +GaussianFactor::sparse(const Dimensions& columnIndices) const { // declare return values list I,J; list S; - // loop over all variables in correct order - size_t column_start = 1; - BOOST_FOREACH(string key, ordering) { - try { - const Matrix& Aj = get_A(key); - for (size_t i = 0; i < Aj.size1(); i++) { - double sigma_i = sigmas_(i); - for (size_t j = 0; j < Aj.size2(); j++) - if (Aj(i, j) != 0.0) { - I.push_back(i + 1); - J.push_back(j + column_start); - S.push_back(Aj(i, j) / sigma_i); - } - } - } catch (std::invalid_argument& exception) { - // it's ok to not have a key in the ordering - } - // find dimension for this key - Dimensions::const_iterator it = variables.find(key); + // iterate over all matrices in the factor + string key; Matrix Aj; + FOREACH_PAIR( key, Aj, As_) { + // find first column index for this key // TODO: check if end() and throw exception if not found - int dim = it->second; - // advance column index to next block by adding dim(key) - column_start += dim; + Dimensions::const_iterator it = columnIndices.find(key); + int column_start = it->second; + for (size_t i = 0; i < Aj.size1(); i++) { + double sigma_i = sigmas_(i); + for (size_t j = 0; j < Aj.size2(); j++) + if (Aj(i, j) != 0.0) { + I.push_back(i + 1); + J.push_back(j + column_start); + S.push_back(Aj(i, j) / sigma_i); + } + } } // return the result diff --git a/cpp/GaussianFactor.h b/cpp/GaussianFactor.h index c13edcc17..ac78789e3 100644 --- a/cpp/GaussianFactor.h +++ b/cpp/GaussianFactor.h @@ -211,10 +211,10 @@ public: * Return vectors i, j, and s to generate an m-by-n sparse matrix * such that S(i(k),j(k)) = s(k), which can be given to MATLAB's sparse. * As above, the standard deviations are baked into A and b - * @param ordering of variables needed for matrix column order + * @param first column index for each variable */ boost::tuple, std::list, std::list > - sparse(const Ordering& ordering, const Dimensions& variables) const; + sparse(const Dimensions& columnIndices) const; /** * Add gradient contribution to gradient config g diff --git a/cpp/GaussianFactorGraph.cpp b/cpp/GaussianFactorGraph.cpp index 341f6bc9d..a245b20db 100644 --- a/cpp/GaussianFactorGraph.cpp +++ b/cpp/GaussianFactorGraph.cpp @@ -156,6 +156,27 @@ pair GaussianFactorGraph::matrix(const Ordering& ordering) const return lf.matrix(ordering); } +/* ************************************************************************* */ +Dimensions GaussianFactorGraph::columnIndices(const Ordering& ordering) const { + + // get the dimensions for all variables + Dimensions variableSet = dimensions(); + + // Find the starting index and dimensions for all variables given the order + size_t j = 1; + Dimensions result; + BOOST_FOREACH(string key, ordering) { + // associate key with first column index + result.insert(make_pair(key,j)); + // find dimension for this key + Dimensions::const_iterator it = variableSet.find(key); + // advance column index to next block by adding dim(key) + j += it->second; + } + + return result; +} + /* ************************************************************************* */ Matrix GaussianFactorGraph::sparse(const Ordering& ordering) const { @@ -163,8 +184,8 @@ Matrix GaussianFactorGraph::sparse(const Ordering& ordering) const { list I,J; list S; - // get the dimensions for all variables - Dimensions variableSet = dimensions(); + // get the starting column indices for all variables + Dimensions indices = columnIndices(ordering); // Collect the I,J,S lists for all factors int row_index = 0; @@ -173,7 +194,7 @@ Matrix GaussianFactorGraph::sparse(const Ordering& ordering) const { // get sparse lists for the factor list i1,j1; list s1; - boost::tie(i1,j1,s1) = factor->sparse(ordering,variableSet); + boost::tie(i1,j1,s1) = factor->sparse(indices); // add row_start to every row index transform(i1.begin(), i1.end(), i1.begin(), bind2nd(plus(), row_index)); diff --git a/cpp/GaussianFactorGraph.h b/cpp/GaussianFactorGraph.h index de6b212bd..1567f251c 100644 --- a/cpp/GaussianFactorGraph.h +++ b/cpp/GaussianFactorGraph.h @@ -152,6 +152,13 @@ namespace gtsam { */ std::pair matrix (const Ordering& ordering) const; + /** + * get the starting column indices for all variables + * @param ordering of variables needed for matrix column order + * @return The set of all variable/index pairs + */ + Dimensions columnIndices(const Ordering& ordering) const; + /** * Return 3*nzmax matrix where the rows correspond to the vectors i, j, and s * to generate an m-by-n sparse matrix, which can be given to MATLAB's sparse function. diff --git a/cpp/testGaussianFactor.cpp b/cpp/testGaussianFactor.cpp index d89a1abb1..f7f339aed 100644 --- a/cpp/testGaussianFactor.cpp +++ b/cpp/testGaussianFactor.cpp @@ -552,7 +552,7 @@ TEST( GaussianFactor, sparse ) list i,j; list s; - boost::tie(i,j,s) = lf->sparse(ord, fg.dimensions()); + boost::tie(i,j,s) = lf->sparse(fg.columnIndices(ord)); list i1,j1; i1 += 1,2,1,2; @@ -581,14 +581,14 @@ TEST( GaussianFactor, sparse2 ) list i,j; list s; - boost::tie(i,j,s) = lf->sparse(ord, fg.dimensions()); + boost::tie(i,j,s) = lf->sparse(fg.columnIndices(ord)); list i1,j1; i1 += 1,2,1,2; - j1 += 1,2,5,6; + j1 += 5,6,1,2; list s1; - s1 += 10,10,-10,-10; + s1 += -10,-10,10,10; CHECK(i==i1); CHECK(j==j1); diff --git a/cpp/testGaussianFactorGraph.cpp b/cpp/testGaussianFactorGraph.cpp index a916d0df4..ba9dffddb 100644 --- a/cpp/testGaussianFactorGraph.cpp +++ b/cpp/testGaussianFactorGraph.cpp @@ -330,11 +330,11 @@ TEST( GaussianFactorGraph, sparse ) Matrix ijs = fg.sparse(ord); - EQUALITY(ijs, Matrix_(3, 14, + EQUALITY(Matrix_(3, 14, // f(x1) f(x2,x1) f(l1,x1) f(x2,l1) - +1., 2., 3., 4., 3., 4., 5.,6., 5., 6., 7., 8.,7.,8., - +5., 6., 1., 2., 5., 6., 3.,4., 5., 6., 1., 2.,3.,4., - 10.,10., 10.,10.,-10.,-10., 5.,5.,-5.,-5., -5.,-5.,5.,5.)); + +1., 2., 3., 4., 3., 4., 5.,6., 5., 6., 7., 8., 7., 8., + +5., 6., 5., 6., 1., 2., 3.,4., 5., 6., 3., 4., 1., 2., + 10.,10., -10.,-10., 10., 10., 5.,5.,-5.,-5., 5., 5.,-5.,-5.), ijs); } /* ************************************************************************* */