diff --git a/gtsam_unstable/nonlinear/Expression-inl.h b/gtsam_unstable/nonlinear/Expression-inl.h index 86a2bfa96..e9addeec7 100644 --- a/gtsam_unstable/nonlinear/Expression-inl.h +++ b/gtsam_unstable/nonlinear/Expression-inl.h @@ -307,10 +307,6 @@ struct Select<2, A> { template class ExpressionNode { -public: - - static size_t const N = 0; // number of arguments - protected: size_t traceSize_; @@ -510,7 +506,7 @@ struct Record: public boost::mpl::fold, // to recursively generate a class, that will be the base for function nodes. // The class generated, for two arguments A1, A2, and A3 will be // -// struct Base1 : Argument, ExpressionNode { +// struct Base1 : Argument, FunctionalBase { // ... storage related to A1 ... // ... methods that work on A1 ... // }; @@ -535,7 +531,21 @@ struct Record: public boost::mpl::fold, //----------------------------------------------------------------------------- /** - * Building block for Recursive FunctionalNode Class + * Base case for recursive FunctionalNode class + */ +template +struct FunctionalBase: ExpressionNode { + static size_t const N = 0; // number of arguments + + typedef CallRecord Record2; + + /// Construct an execution trace for reverse AD + void trace(const Values& values, Record2* record, char*& raw) const { + } +}; + +/** + * Building block for recursive FunctionalNode class * The integer argument N is to guarantee a unique type signature, * so we are guaranteed to be able to extract their values by static cast. */ @@ -562,34 +572,91 @@ struct GenerateFunctionalNode: Argument, Base { return keys; } + /** + * Recursive Record Class for Functional Expressions + */ + struct Record2: JacobianTrace, Base::Record2 { + + typedef T return_type; + typedef JacobianTrace This; + + /// Print to std::cout + virtual void print(const std::string& indent) const { + Base::Record2::print(indent); + static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]"); + std::cout << This::dTdA.format(matlab) << std::endl; + This::trace.print(indent); + } + + /// Start the reverse AD process + virtual void startReverseAD(JacobianMap& jacobians) const { + Base::Record2::startReverseAD(jacobians); + Select::reverseAD(This::trace, This::dTdA, jacobians); + } + + /// Given df/dT, multiply in dT/dA and continue reverse AD process + virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { + Base::Record2::reverseAD(dFdT, jacobians); + This::trace.reverseAD(dFdT * This::dTdA, jacobians); + } + + /// Version specialized to 2-dimensional output + typedef Eigen::Matrix Jacobian2T; + virtual void reverseAD2(const Jacobian2T& dFdT, + JacobianMap& jacobians) const { + Base::Record2::reverseAD2(dFdT, jacobians); + This::trace.reverseAD2(dFdT * This::dTdA, jacobians); + } + }; + + /// Construct an execution trace for reverse AD + void trace(const Values& values, Record2* record, char*& raw) const { + Base::trace(values, record, raw); + A a = This::expression->traceExecution(values, record->Record2::This::trace, raw); + raw = raw + This::expression->traceSize(); + } }; /** * Recursive GenerateFunctionalNode class Generator */ template -struct FunctionalNode: public boost::mpl::fold, - GenerateFunctionalNode >::type { +struct FunctionalNode { + typedef typename boost::mpl::fold, + GenerateFunctionalNode >::type Base; - /// Reset expression shared pointer - template - void reset(const boost::shared_ptr >& ptr) { - static_cast &>(*this).expression = ptr; - } + struct type: public Base { - /// Access Expression shared pointer - template - boost::shared_ptr > expression() const { - return static_cast const &>(*this).expression; - } + /// Reset expression shared pointer + template + void reset(const boost::shared_ptr >& ptr) { + static_cast &>(*this).expression = ptr; + } + /// Access Expression shared pointer + template + boost::shared_ptr > expression() const { + return static_cast const &>(*this).expression; + } + + /// Construct an execution trace for reverse AD + virtual T traceExecution(const Values& values, ExecutionTrace& trace, + char* raw) const { + typename Base::Record2* record = new (raw) typename Base::Record2(); + trace.setFunction(record); + raw = (char*) (record + 1); + + this->trace(values, record, raw); + + return T(); // TODO + } + }; }; - //----------------------------------------------------------------------------- /// Unary Function Expression template -class UnaryExpression: public FunctionalNode > { +class UnaryExpression: public FunctionalNode >::type { /// The automatically generated Base class typedef FunctionalNode > Base; @@ -636,10 +703,11 @@ public: char* raw) const { Record* record = new (raw) Record(); trace.setFunction(record); - raw = (char*) (record + 1); - A1 a1 = this-> template expression()->traceExecution(values, + + A1 a1 = this->template expression()->traceExecution(values, record->template trace(), raw); + raw = raw + this->template expression()->traceSize(); return function_(a1, record->template jacobian()); } @@ -649,7 +717,7 @@ public: /// Binary Expression template -class BinaryExpression: public FunctionalNode > { +class BinaryExpression: public FunctionalNode >::type { public: @@ -706,13 +774,15 @@ public: char* raw) const { Record* record = new (raw) Record(); trace.setFunction(record); - raw = (char*) (record + 1); + A1 a1 = this->template expression()->traceExecution(values, record->template trace(), raw); raw = raw + this->template expression()->traceSize(); + A2 a2 = this->template expression()->traceExecution(values, record->template trace(), raw); + raw = raw + this->template expression()->traceSize(); return function_(a1, a2, record->template jacobian(), record->template jacobian()); @@ -723,7 +793,7 @@ public: /// Ternary Expression template -class TernaryExpression: public FunctionalNode > { +class TernaryExpression: public FunctionalNode >::type { public: @@ -786,16 +856,19 @@ public: char* raw) const { Record* record = new (raw) Record(); trace.setFunction(record); - raw = (char*) (record + 1); + A1 a1 = this->template expression()->traceExecution(values, record->template trace(), raw); raw = raw + this->template expression()->traceSize(); + A2 a2 = this->template expression()->traceExecution(values, record->template trace(), raw); raw = raw + this->template expression()->traceSize(); + A3 a3 = this->template expression()->traceExecution(values, record->template trace(), raw); + raw = raw + this->template expression()->traceSize(); return function_(a1, a2, a3, record->template jacobian(), record->template jacobian(), record->template jacobian());