diff --git a/gtsam_unstable/nonlinear/Expression-inl.h b/gtsam_unstable/nonlinear/Expression-inl.h index 45a932351..87c07f976 100644 --- a/gtsam_unstable/nonlinear/Expression-inl.h +++ b/gtsam_unstable/nonlinear/Expression-inl.h @@ -647,45 +647,23 @@ public: } /// Trace structure for reverse AD - struct Trace: public JacobianTrace { - typename JacobianTrace::Pointer trace1; - typename JacobianTrace::Pointer trace2; - typename JacobianTrace::Pointer trace3; - JacobianTA1 dTdA1; - JacobianTA2 dTdA2; - JacobianTA3 dTdA3; - - /// Start the reverse AD process - virtual void startReverseAD(JacobianMap& jacobians) const { - Select::reverseAD(trace1, dTdA1, jacobians); - Select::reverseAD(trace2, dTdA2, jacobians); - Select::reverseAD(trace3, dTdA3, jacobians); - } - /// Given df/dT, multiply in dT/dA and continue reverse AD process - virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { - trace1.reverseAD(dFdT * dTdA1, jacobians); - trace2.reverseAD(dFdT * dTdA2, jacobians); - trace3.reverseAD(dFdT * dTdA3, jacobians); - } - /// Version specialized to 2-dimensional output - typedef Eigen::Matrix Jacobian2T; - virtual void reverseAD2(const Jacobian2T& dFdT, - JacobianMap& jacobians) const { - trace1.reverseAD2(dFdT * dTdA1, jacobians); - trace2.reverseAD2(dFdT * dTdA2, jacobians); - trace3.reverseAD2(dFdT * dTdA3, jacobians); - } - }; + typedef boost::mpl::vector Arguments; + typedef typename GenerateTrace::type Trace; /// Construct an execution trace for reverse AD virtual T traceExecution(const Values& values, typename JacobianTrace::Pointer& p) const { Trace* trace = new Trace(); p.setFunction(trace); - A1 a1 = this->expressionA1_->traceExecution(values, trace->trace1); - A2 a2 = this->expressionA2_->traceExecution(values, trace->trace2); - A3 a3 = this->expressionA3_->traceExecution(values, trace->trace3); - return function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3); + A1 a1 = this->expressionA1_->traceExecution(values, + static_cast*>(trace)->trace); + A2 a2 = this->expressionA2_->traceExecution(values, + static_cast*>(trace)->trace); + A3 a3 = this->expressionA3_->traceExecution(values, + static_cast*>(trace)->trace); + return function_(a1, a2, a3, static_cast*>(trace)->dTdA, + static_cast*>(trace)->dTdA, + static_cast*>(trace)->dTdA); } };