add trace structure for reverse AD in TernaryExpression
parent
0421d05d44
commit
cc3c0fcfec
|
@ -533,6 +533,26 @@ public:
|
|||
return Augmented<T>(t, H1, argument1.jacobians(), H2, argument2.jacobians(), H3, argument3.jacobians());
|
||||
}
|
||||
|
||||
/// Trace structure for reverse AD
|
||||
struct Trace: public JacobianTrace<T> {
|
||||
boost::shared_ptr<JacobianTrace<A1> > trace1;
|
||||
boost::shared_ptr<JacobianTrace<A2> > trace2;
|
||||
boost::shared_ptr<JacobianTrace<A3> > trace3;
|
||||
Matrix H1, H2, H3;
|
||||
/// Start the reverse AD process
|
||||
virtual void reverseAD(JacobianMap& jacobians) const {
|
||||
trace1->reverseAD(H1, jacobians);
|
||||
trace2->reverseAD(H2, jacobians);
|
||||
trace3->reverseAD(H3, jacobians);
|
||||
}
|
||||
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
||||
virtual void reverseAD(const Matrix& H, JacobianMap& jacobians) const {
|
||||
trace1->reverseAD(H * H1, jacobians);
|
||||
trace2->reverseAD(H * H2, jacobians);
|
||||
trace3->reverseAD(H * H3, jacobians);
|
||||
}
|
||||
};
|
||||
|
||||
};
|
||||
//-----------------------------------------------------------------------------
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue