Attempt at defining Trace recursively

release/4.3a0
dellaert 2014-10-09 14:38:16 +02:00
parent 5e5457b390
commit 8e264f4289
1 changed files with 31 additions and 19 deletions

View File

@ -378,6 +378,36 @@ public:
};
//-----------------------------------------------------------------------------
#include <boost/mpl/list.hpp>
/// Recursive Trace Class
template<class T, class LIST>
struct MakeTrace: public JacobianTrace<T> {
typedef boost::mpl::front<LIST> A1;
static const size_t dimA = A1::dimension;
typedef Eigen::Matrix<double, T::dimension, A1::dimension> JacobianTA;
typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T;
typename JacobianTrace<A1>::Pointer trace1;
JacobianTA dTdA1;
/// Start the reverse AD process
virtual void startReverseAD(JacobianMap& jacobians) const {
Select<T::dimension, A1>::reverseAD(trace1, dTdA1, jacobians);
}
/// Given df/dT, multiply in dT/dA and continue reverse AD process
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
trace1.reverseAD(dFdT * dTdA1, jacobians);
}
/// Version specialized to 2-dimensional output
virtual void reverseAD2(const Jacobian2T& dFdT,
JacobianMap& jacobians) const {
trace1.reverseAD2(dFdT * dTdA1, jacobians);
}
};
//-----------------------------------------------------------------------------
/// Unary Function Expression
template<class T, class A1>
@ -423,25 +453,7 @@ public:
}
/// Trace structure for reverse AD
struct Trace: public JacobianTrace<T> {
typename JacobianTrace<A1>::Pointer trace1;
JacobianTA dTdA1;
/// Start the reverse AD process
virtual void startReverseAD(JacobianMap& jacobians) const {
Select<T::dimension, A1>::reverseAD(trace1, dTdA1, jacobians);
}
/// Given df/dT, multiply in dT/dA and continue reverse AD process
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
trace1.reverseAD(dFdT * dTdA1, jacobians);
}
/// Version specialized to 2-dimensional output
typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T;
virtual void reverseAD2(const Jacobian2T& dFdT,
JacobianMap& jacobians) const {
trace1.reverseAD2(dFdT * dTdA1, jacobians);
}
};
typedef MakeTrace<T, boost::mpl::list1<A1> > Trace;
/// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values,