diff --git a/gtsam/base/Manifold.h b/gtsam/base/Manifold.h index 8ac678e65..2f8dc5f68 100644 --- a/gtsam/base/Manifold.h +++ b/gtsam/base/Manifold.h @@ -46,7 +46,7 @@ struct manifold_tag {}; * There may be multiple possible retractions for a given manifold, which can be chosen * between depending on the computational complexity. The important criteria for * the creation for the retract and localCoordinates functions is that they be - * inverse operations. The new notion of a Chart guarantees that. + * inverse operations. * */ @@ -90,9 +90,9 @@ struct ManifoldImpl { /// A helper that implements the traits interface for GTSAM manifolds. /// To use this for your class type, define: -/// template<> struct traits : public internal::Manifold { }; +/// template<> struct traits : public internal::ManifoldTraits { }; template -struct Manifold: Testable, ManifoldImpl { +struct ManifoldTraits: ManifoldImpl { // Check that Class has the necessary machinery BOOST_CONCEPT_ASSERT((HasManifoldPrereqs)); @@ -116,6 +116,11 @@ struct Manifold: Testable, ManifoldImpl { } }; +/// Implement both manifold and testable traits at the same time +template +struct Manifold: Testable, ManifoldTraits { +}; + } // \ namespace internal /// Check invariants for Manifold type @@ -173,33 +178,37 @@ class ProductManifold: public std::pair { BOOST_CONCEPT_ASSERT((IsManifold)); private: - const M1& g() const {return this->first;} - const M2& h() const {return this->second;} + enum { dimension1 = traits::dimension }; + enum { dimension2 = traits::dimension }; public: - enum { dimension = M1::dimension + M2::dimension }; + enum { dimension = dimension1 + dimension2 }; inline static size_t Dim() { return dimension;} inline size_t dim() const { return dimension;} typedef Eigen::Matrix TangentVector; + typedef OptionalJacobian ChartJacobian; /// Default constructor yields identity ProductManifold():std::pair(traits::Identity(),traits::Identity()) {} - // Construct from two subgroup elements - ProductManifold(const M1& g, const M2& h):std::pair(g,h) {} + // Construct from two original manifold values + ProductManifold(const M1& m1, const M2& m2):std::pair(m1,m2) {} /// Retract delta to manifold Derived retract(const TangentVector& xi) const { - return Derived(traits::Retract(g(),xi.head(M1::dimension)), - traits::Retract(h(),xi.tail(M2::dimension))); + M1 m1 = traits::Retract(this->first, xi.template head()); + M2 m2 = traits::Retract(this->second, xi.template tail()); + return Derived(m1,m2); } /// Compute the coordinates in the tangent space TangentVector localCoordinates(const Derived& other) const { - TangentVector xi; - xi << traits::Local(g(),other.g()), traits::Local(h(),other.h()); - return xi; + typename traits::TangentVector v1 = traits::Local(this->first, other.first); + typename traits::TangentVector v2 = traits::Local(this->second, other.second); + TangentVector v; + v << v1, v2; + return v; } }; diff --git a/tests/testManifold.cpp b/tests/testManifold.cpp index ef0456146..496579b8d 100644 --- a/tests/testManifold.cpp +++ b/tests/testManifold.cpp @@ -148,6 +148,37 @@ TEST(Manifold, DefaultChart) { EXPECT(assert_equal(zero(3), traits::Local(R, R))); } +//****************************************************************************** +struct MyPoint2Pair : public ProductManifold { + typedef ProductManifold Base; + MyPoint2Pair(const Point2& p1, const Point2& p2):Base(p1,p2) {} + MyPoint2Pair(const Base& base):Base(base) {} + MyPoint2Pair() {} +}; + +// Define any direct product group to be a model of the multiplicative Group concept +namespace gtsam { +template<> struct traits : internal::ManifoldTraits { + static void Print(const MyPoint2Pair& m, const string& s = "") { + cout << s << "(" << m.first << "," << m.second << ")" << endl; + } + static bool Equals(const MyPoint2Pair& m1, const MyPoint2Pair& m2, double tol = 1e-8) { + return m1 == m2; + } +}; +} + +TEST(Manifold, ProductManifold) { + BOOST_CONCEPT_ASSERT((IsManifold)); + MyPoint2Pair pair1; + Vector4 d; + d << 1,2,3,4; + MyPoint2Pair expected(Point2(1,2),Point2(3,4)); + MyPoint2Pair pair2 = pair1.retract(d); + EXPECT(assert_equal(expected,pair2,1e-9)); + EXPECT(assert_equal(d, pair1.localCoordinates(pair2),1e-9)); +} + //****************************************************************************** int main() { TestResult tr;