Ternary works, same caveat
							parent
							
								
									406467e341
								
							
						
					
					
						commit
						58bbce482d
					
				|  | @ -647,45 +647,23 @@ public: | |||
|   } | ||||
| 
 | ||||
|   /// Trace structure for reverse AD
 | ||||
|   struct Trace: public JacobianTrace<T> { | ||||
|     typename JacobianTrace<A1>::Pointer trace1; | ||||
|     typename JacobianTrace<A2>::Pointer trace2; | ||||
|     typename JacobianTrace<A3>::Pointer trace3; | ||||
|     JacobianTA1 dTdA1; | ||||
|     JacobianTA2 dTdA2; | ||||
|     JacobianTA3 dTdA3; | ||||
| 
 | ||||
|     /// Start the reverse AD process
 | ||||
|     virtual void startReverseAD(JacobianMap& jacobians) const { | ||||
|       Select<T::dimension, A1>::reverseAD(trace1, dTdA1, jacobians); | ||||
|       Select<T::dimension, A2>::reverseAD(trace2, dTdA2, jacobians); | ||||
|       Select<T::dimension, A3>::reverseAD(trace3, dTdA3, jacobians); | ||||
|     } | ||||
|     /// Given df/dT, multiply in dT/dA and continue reverse AD process
 | ||||
|     virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { | ||||
|       trace1.reverseAD(dFdT * dTdA1, jacobians); | ||||
|       trace2.reverseAD(dFdT * dTdA2, jacobians); | ||||
|       trace3.reverseAD(dFdT * dTdA3, jacobians); | ||||
|     } | ||||
|     /// Version specialized to 2-dimensional output
 | ||||
|     typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T; | ||||
|     virtual void reverseAD2(const Jacobian2T& dFdT, | ||||
|         JacobianMap& jacobians) const { | ||||
|       trace1.reverseAD2(dFdT * dTdA1, jacobians); | ||||
|       trace2.reverseAD2(dFdT * dTdA2, jacobians); | ||||
|       trace3.reverseAD2(dFdT * dTdA3, jacobians); | ||||
|     } | ||||
|   }; | ||||
|   typedef boost::mpl::vector<A1, A2, A3> Arguments; | ||||
|   typedef typename GenerateTrace<T, Arguments>::type Trace; | ||||
| 
 | ||||
|   /// Construct an execution trace for reverse AD
 | ||||
|   virtual T traceExecution(const Values& values, | ||||
|       typename JacobianTrace<T>::Pointer& p) const { | ||||
|     Trace* trace = new Trace(); | ||||
|     p.setFunction(trace); | ||||
|     A1 a1 = this->expressionA1_->traceExecution(values, trace->trace1); | ||||
|     A2 a2 = this->expressionA2_->traceExecution(values, trace->trace2); | ||||
|     A3 a3 = this->expressionA3_->traceExecution(values, trace->trace3); | ||||
|     return function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3); | ||||
|     A1 a1 = this->expressionA1_->traceExecution(values, | ||||
|         static_cast<SingleTrace<T, A1>*>(trace)->trace); | ||||
|     A2 a2 = this->expressionA2_->traceExecution(values, | ||||
|         static_cast<SingleTrace<T, A2>*>(trace)->trace); | ||||
|     A3 a3 = this->expressionA3_->traceExecution(values, | ||||
|         static_cast<SingleTrace<T, A3>*>(trace)->trace); | ||||
|     return function_(a1, a2, a3, static_cast<SingleTrace<T, A1>*>(trace)->dTdA, | ||||
|         static_cast<SingleTrace<T, A2>*>(trace)->dTdA, | ||||
|         static_cast<SingleTrace<T, A3>*>(trace)->dTdA); | ||||
|   } | ||||
| 
 | ||||
| }; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue