diff --git a/gtsam_unstable/nonlinear/Expression-inl.h b/gtsam_unstable/nonlinear/Expression-inl.h index f0301ba4a..f9a0c91bf 100644 --- a/gtsam_unstable/nonlinear/Expression-inl.h +++ b/gtsam_unstable/nonlinear/Expression-inl.h @@ -336,6 +336,12 @@ public: return keys; } + /// Return dimensions for each argument + virtual std::map dimensions() const { + std::map map; + return map; + } + // Return size needed for memory buffer in traceExecution size_t traceSize() const { return traceSize_; @@ -410,6 +416,13 @@ public: return keys; } + /// Return dimensions for each argument + virtual std::map dimensions() const { + std::map map; + map[key_] = T::dimension; + return map; + } + /// Return value virtual T value(const Values& values) const { return values.at(key_); @@ -526,6 +539,14 @@ struct GenerateFunctionalNode: Argument, Base { return keys; } + /// Return dimensions for each argument + virtual std::map dimensions() const { + std::map map = Base::dimensions(); + std::map myMap = This::expression->dimensions(); + map.insert(myMap.begin(), myMap.end()); + return map; + } + /// Recursive Record Class for Functional Expressions struct Record: JacobianTrace, Base::Record { diff --git a/gtsam_unstable/nonlinear/Expression.h b/gtsam_unstable/nonlinear/Expression.h index 23621f2bb..2f6367734 100644 --- a/gtsam_unstable/nonlinear/Expression.h +++ b/gtsam_unstable/nonlinear/Expression.h @@ -107,6 +107,11 @@ public: return root_->keys(); } + /// Return dimensions for each argument, as a map (allows order to change later) + std::map dimensions() const { + return root_->dimensions(); + } + /// Return value and derivatives, forward AD version Augmented forward(const Values& values) const { return root_->forward(values); diff --git a/gtsam_unstable/nonlinear/tests/testExpression.cpp b/gtsam_unstable/nonlinear/tests/testExpression.cpp index bf13749b9..e6fd12ab4 100644 --- a/gtsam_unstable/nonlinear/tests/testExpression.cpp +++ b/gtsam_unstable/nonlinear/tests/testExpression.cpp @@ -26,9 +26,15 @@ #include +#include +using boost::assign::list_of; +using boost::assign::map_list_of; + using namespace std; using namespace gtsam; +typedef pair Pair; + /* ************************************************************************* */ template @@ -94,13 +100,18 @@ Expression p_cam(x, &Pose3::transform_to, p); } /* ************************************************************************* */ // keys -TEST(Expression, keys_binary) { - - // Check keys - set expectedKeys; - expectedKeys.insert(1); - expectedKeys.insert(2); - EXPECT(expectedKeys == binary::p_cam.keys()); +TEST(Expression, BinaryKeys) { + set expected = list_of(1)(2); + EXPECT(expected == binary::p_cam.keys()); +} +/* ************************************************************************* */ +// dimensions +TEST(Expression, BinaryDimensions) { + map expected = map_list_of(1, 6)(2, 3), // + actual = binary::p_cam.dimensions(); + EXPECT_LONGS_EQUAL(expected.size(),actual.size()); + BOOST_FOREACH(Pair pair, actual) + EXPECT_LONGS_EQUAL(expected[pair.first],pair.second); } /* ************************************************************************* */ // Binary(Leaf,Unary(Binary(Leaf,Leaf))) @@ -115,14 +126,18 @@ Expression uv_hat(uncalibrate, K, projection); } /* ************************************************************************* */ // keys -TEST(Expression, keys_tree) { - - // Check keys - set expectedKeys; - expectedKeys.insert(1); - expectedKeys.insert(2); - expectedKeys.insert(3); - EXPECT(expectedKeys == tree::uv_hat.keys()); +TEST(Expression, TreeKeys) { + set expected = list_of(1)(2)(3); + EXPECT(expected == tree::uv_hat.keys()); +} +/* ************************************************************************* */ +// dimensions +TEST(Expression, TreeDimensions) { + map expected = map_list_of(1, 6)(2, 3)(3, 5), // + actual = tree::uv_hat.dimensions(); + EXPECT_LONGS_EQUAL(expected.size(),actual.size()); + BOOST_FOREACH(Pair pair, actual) + EXPECT_LONGS_EQUAL(expected[pair.first],pair.second); } /* ************************************************************************* */ @@ -133,10 +148,8 @@ TEST(Expression, compose1) { Expression R3 = R1 * R2; // Check keys - set expectedKeys; - expectedKeys.insert(1); - expectedKeys.insert(2); - EXPECT(expectedKeys == R3.keys()); + set expected = list_of(1)(2); + EXPECT(expected == R3.keys()); } /* ************************************************************************* */ @@ -148,9 +161,8 @@ TEST(Expression, compose2) { Expression R3 = R1 * R2; // Check keys - set expectedKeys; - expectedKeys.insert(1); - EXPECT(expectedKeys == R3.keys()); + set expected = list_of(1); + EXPECT(expected == R3.keys()); } /* ************************************************************************* */ @@ -162,9 +174,8 @@ TEST(Expression, compose3) { Expression R3 = R1 * R2; // Check keys - set expectedKeys; - expectedKeys.insert(3); - EXPECT(expectedKeys == R3.keys()); + set expected = list_of(3); + EXPECT(expected == R3.keys()); } /* ************************************************************************* */ @@ -189,11 +200,8 @@ TEST(Expression, ternary) { Expression ABC(composeThree, A, B, C); // Check keys - set expectedKeys; - expectedKeys.insert(1); - expectedKeys.insert(2); - expectedKeys.insert(3); - EXPECT(expectedKeys == ABC.keys()); + set expected = list_of(1)(2)(3); + EXPECT(expected == ABC.keys()); } /* ************************************************************************* */