From 303d37a7161c30074ce56d3e213dc66cbdce3a98 Mon Sep 17 00:00:00 2001 From: dellaert Date: Sun, 5 Oct 2014 11:22:14 +0200 Subject: [PATCH] Separate hierarchy --- gtsam_unstable/nonlinear/Expression-inl.h | 89 ++++++++++++++++++- gtsam_unstable/nonlinear/Expression.h | 4 +- .../nonlinear/tests/testExpression.cpp | 10 ++- 3 files changed, 96 insertions(+), 7 deletions(-) diff --git a/gtsam_unstable/nonlinear/Expression-inl.h b/gtsam_unstable/nonlinear/Expression-inl.h index 3cb735b6e..7f371b886 100644 --- a/gtsam_unstable/nonlinear/Expression-inl.h +++ b/gtsam_unstable/nonlinear/Expression-inl.h @@ -107,6 +107,84 @@ public: } }; +//----------------------------------------------------------------------------- +/** + * Execution trace for reverse AD + */ +template +class JacobianTrace { + +public: + + /// Constructor + JacobianTrace() { + } + + virtual ~JacobianTrace() { + } + + /// Return value + const T& value() const = 0; + + /// Return value and derivatives + virtual Augmented augmented() const = 0; +}; + +template +class JacobianTraceConstant : public JacobianTrace { + +protected: + + T constant_; + +public: + + /// Constructor + JacobianTraceConstant(const T& constant) : + constant_(constant) { + } + + virtual ~JacobianTraceConstant() { + } + + /// Return value + const T& value() const { + return constant_; + } + + /// Return value and derivatives + virtual Augmented augmented() const { + return Augmented(constant_); + } +}; + +template +class JacobianTraceLeaf : public JacobianTrace { + +protected: + + T value_; + +public: + + /// Constructor + JacobianTraceLeaf(const T& value) : + value_(value) { + } + + virtual ~JacobianTraceLeaf() { + } + + /// Return value + const T& value() const { + return value_; + } + + /// Return value and derivatives + virtual Augmented augmented() const { + return Augmented(value_); + } +}; //----------------------------------------------------------------------------- /** * Expression node. The superclass for objects that do the heavy lifting @@ -137,6 +215,10 @@ public: /// Return value and derivatives virtual Augmented forward(const Values& values) const = 0; + /// Construct an execution trace for reverse AD + virtual JacobianTrace reverse(const Values& values) const { + return JacobianTrace(T()); + } }; //----------------------------------------------------------------------------- @@ -173,10 +255,13 @@ public: /// Return value and derivatives virtual Augmented forward(const Values& values) const { - T t = value(values); - return Augmented(t); + return Augmented(constant_); } + /// Construct an execution trace for reverse AD + virtual JacobianTrace reverse(const Values& values) const { + return JacobianTrace(constant_); + } }; //----------------------------------------------------------------------------- diff --git a/gtsam_unstable/nonlinear/Expression.h b/gtsam_unstable/nonlinear/Expression.h index 27f51893c..bd17febf0 100644 --- a/gtsam_unstable/nonlinear/Expression.h +++ b/gtsam_unstable/nonlinear/Expression.h @@ -103,7 +103,9 @@ public: /// Return value and derivatives Augmented augmented(const Values& values) const { - return root_->forward(values); + JacobianTrace trace = root_->reverse(values); + return trace.augmented(); +// return root_->forward(values); } const boost::shared_ptr >& root() const { diff --git a/gtsam_unstable/nonlinear/tests/testExpression.cpp b/gtsam_unstable/nonlinear/tests/testExpression.cpp index 057359155..19a54c755 100644 --- a/gtsam_unstable/nonlinear/tests/testExpression.cpp +++ b/gtsam_unstable/nonlinear/tests/testExpression.cpp @@ -36,13 +36,15 @@ Point2 uncalibrate(const CAL& K, const Point2& p, boost::optional Dcal, return K.uncalibrate(p, Dcal, Dp); } +static const Rot3 someR = Rot3::RzRyRx(1,2,3); + /* ************************************************************************* */ TEST(Expression, constant) { - Expression R(Rot3::identity()); + Expression R(someR); Values values; Augmented a = R.augmented(values); - EXPECT(assert_equal(Rot3::identity(), a.value())); + EXPECT(assert_equal(someR, a.value())); JacobianMap expected; EXPECT(a.jacobians() == expected); } @@ -52,9 +54,9 @@ TEST(Expression, constant) { TEST(Expression, leaf) { Expression R(100); Values values; - values.insert(100,Rot3::identity()); + values.insert(100,someR); Augmented a = R.augmented(values); - EXPECT(assert_equal(Rot3::identity(), a.value())); + EXPECT(assert_equal(someR, a.value())); JacobianMap expected; expected[100] = eye(3); EXPECT(a.jacobians() == expected);