diff --git a/gtsam_unstable/nonlinear/Expression-inl.h b/gtsam_unstable/nonlinear/Expression-inl.h index 85920735a..c40dfb405 100644 --- a/gtsam_unstable/nonlinear/Expression-inl.h +++ b/gtsam_unstable/nonlinear/Expression-inl.h @@ -198,36 +198,29 @@ struct Argument { template struct Record: Argument, More { + typedef T return_type; typedef typename AN::type A; const static size_t N = AN::value; - - ExecutionTrace const & myTrace() const { - return static_cast*>(this)->trace; - } - - typedef Eigen::Matrix JacobianTA; - const JacobianTA& myJacobian() const { - return static_cast*>(this)->dTdA; - } + typedef Argument This; /// Print to std::cout virtual void print(const std::string& indent) const { More::print(indent); static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]"); - std::cout << myJacobian().format(matlab) << std::endl; - myTrace().print(indent); + std::cout << This::dTdA.format(matlab) << std::endl; + This::trace.print(indent); } /// Start the reverse AD process virtual void startReverseAD(JacobianMap& jacobians) const { More::startReverseAD(jacobians); - Select::reverseAD(myTrace(), myJacobian(), jacobians); + Select::reverseAD(This::trace, This::dTdA, jacobians); } /// Given df/dT, multiply in dT/dA and continue reverse AD process virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { More::reverseAD(dFdT, jacobians); - myTrace().reverseAD(dFdT * myJacobian(), jacobians); + This::trace.reverseAD(dFdT * This::dTdA, jacobians); } /// Version specialized to 2-dimensional output @@ -235,7 +228,7 @@ struct Record: Argument, More { virtual void reverseAD2(const Jacobian2T& dFdT, JacobianMap& jacobians) const { More::reverseAD2(dFdT, jacobians); - myTrace().reverseAD2(dFdT * myJacobian(), jacobians); + This::trace.reverseAD2(dFdT * This::dTdA, jacobians); } }; @@ -252,9 +245,27 @@ template struct GenerateRecord { typedef typename boost::mpl::fold, Record >::type type; - }; +/// Access Argument +template +Argument& argument(Record& record) { + return static_cast&>(record); +} + +/// Access Trace +template +ExecutionTrace& getTrace(Record* record) { + return argument(*record).trace; +} + +/// Access Jacobian +template +Eigen::Matrix& jacobian( + Record* record) { + return argument(*record).dTdA; +} + //----------------------------------------------------------------------------- /** * Value and Jacobians @@ -552,10 +563,9 @@ public: trace.setFunction(record); raw = (char*) (record + 1); - A1 a1 = this->expressionA1_->traceExecution(values, - static_cast*>(record)->trace, raw); + A1 a1 = expressionA1_->traceExecution(values, getTrace(record), raw); - return function_(a1, static_cast*>(record)->dTdA); + return function_(a1, jacobian(record)); } }; @@ -636,15 +646,11 @@ public: trace.setFunction(record); raw = (char*) (record + 1); - A1 a1 = this->expressionA1_->traceExecution(values, - static_cast*>(record)->trace, raw); - + A1 a1 = expressionA1_->traceExecution(values, getTrace(record), raw); raw = raw + expressionA1_->traceSize(); - A2 a2 = this->expressionA2_->traceExecution(values, - static_cast*>(record)->trace, raw); + A2 a2 = expressionA2_->traceExecution(values, getTrace(record), raw); - return function_(a1, a2, static_cast*>(record)->dTdA, - static_cast*>(record)->dTdA); + return function_(a1, a2, jacobian(record), jacobian(record)); } }; @@ -736,20 +742,14 @@ public: trace.setFunction(record); raw = (char*) (record + 1); - A1 a1 = this->expressionA1_->traceExecution(values, - static_cast*>(record)->trace, raw); - + A1 a1 = expressionA1_->traceExecution(values, getTrace(record), raw); raw = raw + expressionA1_->traceSize(); - A2 a2 = this->expressionA2_->traceExecution(values, - static_cast*>(record)->trace, raw); - + A2 a2 = expressionA2_->traceExecution(values, getTrace(record), raw); raw = raw + expressionA2_->traceSize(); - A3 a3 = this->expressionA3_->traceExecution(values, - static_cast*>(record)->trace, raw); + A3 a3 = expressionA3_->traceExecution(values, getTrace(record), raw); - return function_(a1, a2, a3, static_cast*>(record)->dTdA, - static_cast*>(record)->dTdA, - static_cast*>(record)->dTdA); + return function_(a1, a2, a3, jacobian(record), + jacobian(record), jacobian(record)); } };