add trace structure for reverse AD in TernaryExpression
parent
0421d05d44
commit
cc3c0fcfec
|
@ -533,6 +533,26 @@ public:
|
||||||
return Augmented<T>(t, H1, argument1.jacobians(), H2, argument2.jacobians(), H3, argument3.jacobians());
|
return Augmented<T>(t, H1, argument1.jacobians(), H2, argument2.jacobians(), H3, argument3.jacobians());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Trace structure for reverse AD
|
||||||
|
struct Trace: public JacobianTrace<T> {
|
||||||
|
boost::shared_ptr<JacobianTrace<A1> > trace1;
|
||||||
|
boost::shared_ptr<JacobianTrace<A2> > trace2;
|
||||||
|
boost::shared_ptr<JacobianTrace<A3> > trace3;
|
||||||
|
Matrix H1, H2, H3;
|
||||||
|
/// Start the reverse AD process
|
||||||
|
virtual void reverseAD(JacobianMap& jacobians) const {
|
||||||
|
trace1->reverseAD(H1, jacobians);
|
||||||
|
trace2->reverseAD(H2, jacobians);
|
||||||
|
trace3->reverseAD(H3, jacobians);
|
||||||
|
}
|
||||||
|
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
||||||
|
virtual void reverseAD(const Matrix& H, JacobianMap& jacobians) const {
|
||||||
|
trace1->reverseAD(H * H1, jacobians);
|
||||||
|
trace2->reverseAD(H * H2, jacobians);
|
||||||
|
trace3->reverseAD(H * H3, jacobians);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue