Ternary works, same caveat
parent
406467e341
commit
58bbce482d
|
@ -647,45 +647,23 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trace structure for reverse AD
|
/// Trace structure for reverse AD
|
||||||
struct Trace: public JacobianTrace<T> {
|
typedef boost::mpl::vector<A1, A2, A3> Arguments;
|
||||||
typename JacobianTrace<A1>::Pointer trace1;
|
typedef typename GenerateTrace<T, Arguments>::type Trace;
|
||||||
typename JacobianTrace<A2>::Pointer trace2;
|
|
||||||
typename JacobianTrace<A3>::Pointer trace3;
|
|
||||||
JacobianTA1 dTdA1;
|
|
||||||
JacobianTA2 dTdA2;
|
|
||||||
JacobianTA3 dTdA3;
|
|
||||||
|
|
||||||
/// Start the reverse AD process
|
|
||||||
virtual void startReverseAD(JacobianMap& jacobians) const {
|
|
||||||
Select<T::dimension, A1>::reverseAD(trace1, dTdA1, jacobians);
|
|
||||||
Select<T::dimension, A2>::reverseAD(trace2, dTdA2, jacobians);
|
|
||||||
Select<T::dimension, A3>::reverseAD(trace3, dTdA3, jacobians);
|
|
||||||
}
|
|
||||||
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
|
||||||
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
|
|
||||||
trace1.reverseAD(dFdT * dTdA1, jacobians);
|
|
||||||
trace2.reverseAD(dFdT * dTdA2, jacobians);
|
|
||||||
trace3.reverseAD(dFdT * dTdA3, jacobians);
|
|
||||||
}
|
|
||||||
/// Version specialized to 2-dimensional output
|
|
||||||
typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T;
|
|
||||||
virtual void reverseAD2(const Jacobian2T& dFdT,
|
|
||||||
JacobianMap& jacobians) const {
|
|
||||||
trace1.reverseAD2(dFdT * dTdA1, jacobians);
|
|
||||||
trace2.reverseAD2(dFdT * dTdA2, jacobians);
|
|
||||||
trace3.reverseAD2(dFdT * dTdA3, jacobians);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values,
|
virtual T traceExecution(const Values& values,
|
||||||
typename JacobianTrace<T>::Pointer& p) const {
|
typename JacobianTrace<T>::Pointer& p) const {
|
||||||
Trace* trace = new Trace();
|
Trace* trace = new Trace();
|
||||||
p.setFunction(trace);
|
p.setFunction(trace);
|
||||||
A1 a1 = this->expressionA1_->traceExecution(values, trace->trace1);
|
A1 a1 = this->expressionA1_->traceExecution(values,
|
||||||
A2 a2 = this->expressionA2_->traceExecution(values, trace->trace2);
|
static_cast<SingleTrace<T, A1>*>(trace)->trace);
|
||||||
A3 a3 = this->expressionA3_->traceExecution(values, trace->trace3);
|
A2 a2 = this->expressionA2_->traceExecution(values,
|
||||||
return function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3);
|
static_cast<SingleTrace<T, A2>*>(trace)->trace);
|
||||||
|
A3 a3 = this->expressionA3_->traceExecution(values,
|
||||||
|
static_cast<SingleTrace<T, A3>*>(trace)->trace);
|
||||||
|
return function_(a1, a2, a3, static_cast<SingleTrace<T, A1>*>(trace)->dTdA,
|
||||||
|
static_cast<SingleTrace<T, A2>*>(trace)->dTdA,
|
||||||
|
static_cast<SingleTrace<T, A3>*>(trace)->dTdA);
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue