diff --git a/gtsam/gtsam.i b/gtsam/gtsam.i index 67c3278a3..6ac63c93f 100644 --- a/gtsam/gtsam.i +++ b/gtsam/gtsam.i @@ -165,6 +165,7 @@ gtsam::Values allPose2s(gtsam::Values& values); Matrix extractPose2(const gtsam::Values& values); gtsam::Values allPose3s(gtsam::Values& values); Matrix extractPose3(const gtsam::Values& values); +Matrix extractVectors(const gtsam::Values& values, char c); void perturbPoint2(gtsam::Values& values, double sigma, int seed = 42u); void perturbPose2(gtsam::Values& values, double sigmaT, double sigmaR, int seed = 42u); diff --git a/gtsam/geometry/tests/testUtilities.cpp b/gtsam/nonlinear/tests/testUtilities.cpp similarity index 68% rename from gtsam/geometry/tests/testUtilities.cpp rename to gtsam/nonlinear/tests/testUtilities.cpp index 25ac3acc8..55a7fdb13 100644 --- a/gtsam/geometry/tests/testUtilities.cpp +++ b/gtsam/nonlinear/tests/testUtilities.cpp @@ -21,7 +21,6 @@ #include #include #include -#include #include using namespace gtsam; @@ -55,6 +54,26 @@ TEST(Utilities, ExtractPoint3) { EXPECT_LONGS_EQUAL(2, all_points.rows()); } +/* ************************************************************************* */ +TEST(Utilities, ExtractVector) { + // Test normal case with 3 vectors and 1 non-vector (ignore non-vector) + auto values = Values(); + values.insert(X(0), (Vector(4) << 1, 2, 3, 4).finished()); + values.insert(X(2), (Vector(4) << 13, 14, 15, 16).finished()); + values.insert(X(1), (Vector(4) << 6, 7, 8, 9).finished()); + values.insert(X(3), Pose3()); + auto actual = utilities::extractVectors(values, 'x'); + auto expected = + (Matrix(3, 4) << 1, 2, 3, 4, 6, 7, 8, 9, 13, 14, 15, 16).finished(); + EXPECT(assert_equal(expected, actual)); + + // Check that mis-sized vectors fail + values.insert(X(4), (Vector(2) << 1, 2).finished()); + THROWS_EXCEPTION(utilities::extractVectors(values, 'x')); + values.update(X(4), (Vector(6) << 1, 2, 3, 4, 5, 6).finished()); + THROWS_EXCEPTION(utilities::extractVectors(values, 'x')); +} + /* ************************************************************************* */ int main() { srand(time(nullptr)); diff --git a/gtsam/nonlinear/utilities.h b/gtsam/nonlinear/utilities.h index fdc1da2c4..d2b38d374 100644 --- a/gtsam/nonlinear/utilities.h +++ b/gtsam/nonlinear/utilities.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -162,6 +163,34 @@ Matrix extractPose3(const Values& values) { return result; } +/// Extract all Vector values with a given symbol character into an mxn matrix, +/// where m is the number of symbols that match the character and n is the +/// dimension of the variables. If not all variables have dimension n, then a +/// runtime error will be thrown. The order of returned values are sorted by +/// the symbol. +/// For example, calling extractVector(values, 'x'), where values contains 200 +/// variables x1, x2, ..., x200 of type Vector each 5-dimensional, will return a +/// 200x5 matrix with row i containing xi. +Matrix extractVectors(const Values& values, char c) { + Values::ConstFiltered vectors = + values.filter(Symbol::ChrTest(c)); + if (vectors.size() == 0) { + return Matrix(); + } + auto dim = vectors.begin()->value.size(); + Matrix result(vectors.size(), dim); + Eigen::Index rowi = 0; + for (const auto& kv : vectors) { + if (kv.value.size() != dim) { + throw std::runtime_error( + "Tried to extract different-sized vectors into a single matrix"); + } + result.row(rowi) = kv.value; + ++rowi; + } + return result; +} + /// Perturb all Point2 values using normally distributed noise void perturbPoint2(Values& values, double sigma, int32_t seed = 42u) { noiseModel::Isotropic::shared_ptr model = diff --git a/matlab/gtsam_tests/testUtilities.m b/matlab/gtsam_tests/testUtilities.m index da8dec789..2bfe81a83 100644 --- a/matlab/gtsam_tests/testUtilities.m +++ b/matlab/gtsam_tests/testUtilities.m @@ -45,3 +45,12 @@ CHECK('KeySet', isa(actual,'gtsam.KeySet')); CHECK('size==3', actual.size==3); CHECK('actual.count(x1)', actual.count(x1)); +% test extractVectors +values = Values(); +values.insert(symbol('x', 0), (1:6)'); +values.insert(symbol('x', 1), (7:12)'); +values.insert(symbol('x', 2), (13:18)'); +values.insert(symbol('x', 7), Pose3()); +actual = utilities.extractVectors(values, 'x'); +expected = reshape(1:18, 6, 3)'; +CHECK('extractVectors', all(actual == expected, 'all'));