Refactoring for readability/sanity
parent
1bac83381f
commit
c776e87f78
|
@ -198,36 +198,29 @@ struct Argument {
|
|||
template<class T, class AN, class More>
|
||||
struct Record: Argument<T, typename AN::type, AN::value>, More {
|
||||
|
||||
typedef T return_type;
|
||||
typedef typename AN::type A;
|
||||
const static size_t N = AN::value;
|
||||
|
||||
ExecutionTrace<A> const & myTrace() const {
|
||||
return static_cast<const Argument<T, A, AN::value>*>(this)->trace;
|
||||
}
|
||||
|
||||
typedef Eigen::Matrix<double, T::dimension, A::dimension> JacobianTA;
|
||||
const JacobianTA& myJacobian() const {
|
||||
return static_cast<const Argument<T, A, AN::value>*>(this)->dTdA;
|
||||
}
|
||||
typedef Argument<T, A, N> This;
|
||||
|
||||
/// Print to std::cout
|
||||
virtual void print(const std::string& indent) const {
|
||||
More::print(indent);
|
||||
static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]");
|
||||
std::cout << myJacobian().format(matlab) << std::endl;
|
||||
myTrace().print(indent);
|
||||
std::cout << This::dTdA.format(matlab) << std::endl;
|
||||
This::trace.print(indent);
|
||||
}
|
||||
|
||||
/// Start the reverse AD process
|
||||
virtual void startReverseAD(JacobianMap& jacobians) const {
|
||||
More::startReverseAD(jacobians);
|
||||
Select<T::dimension, A>::reverseAD(myTrace(), myJacobian(), jacobians);
|
||||
Select<T::dimension, A>::reverseAD(This::trace, This::dTdA, jacobians);
|
||||
}
|
||||
|
||||
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
||||
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
|
||||
More::reverseAD(dFdT, jacobians);
|
||||
myTrace().reverseAD(dFdT * myJacobian(), jacobians);
|
||||
This::trace.reverseAD(dFdT * This::dTdA, jacobians);
|
||||
}
|
||||
|
||||
/// Version specialized to 2-dimensional output
|
||||
|
@ -235,7 +228,7 @@ struct Record: Argument<T, typename AN::type, AN::value>, More {
|
|||
virtual void reverseAD2(const Jacobian2T& dFdT,
|
||||
JacobianMap& jacobians) const {
|
||||
More::reverseAD2(dFdT, jacobians);
|
||||
myTrace().reverseAD2(dFdT * myJacobian(), jacobians);
|
||||
This::trace.reverseAD2(dFdT * This::dTdA, jacobians);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -252,9 +245,27 @@ template<class T, class TYPES>
|
|||
struct GenerateRecord {
|
||||
typedef typename boost::mpl::fold<TYPES, CallRecord<T::dimension>,
|
||||
Record<T, MPL::_2, MPL::_1> >::type type;
|
||||
|
||||
};
|
||||
|
||||
/// Access Argument
|
||||
template<class A, size_t N, class Record>
|
||||
Argument<typename Record::return_type, A, N>& argument(Record& record) {
|
||||
return static_cast<Argument<typename Record::return_type, A, N>&>(record);
|
||||
}
|
||||
|
||||
/// Access Trace
|
||||
template<class A, size_t N, class Record>
|
||||
ExecutionTrace<A>& getTrace(Record* record) {
|
||||
return argument<A, N>(*record).trace;
|
||||
}
|
||||
|
||||
/// Access Jacobian
|
||||
template<class A, size_t N, class Record>
|
||||
Eigen::Matrix<double, Record::return_type::dimension, A::dimension>& jacobian(
|
||||
Record* record) {
|
||||
return argument<A, N>(*record).dTdA;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
/**
|
||||
* Value and Jacobians
|
||||
|
@ -552,10 +563,9 @@ public:
|
|||
trace.setFunction(record);
|
||||
|
||||
raw = (char*) (record + 1);
|
||||
A1 a1 = this->expressionA1_->traceExecution(values,
|
||||
static_cast<Argument<T, A1, 1>*>(record)->trace, raw);
|
||||
A1 a1 = expressionA1_->traceExecution(values, getTrace<A1, 1>(record), raw);
|
||||
|
||||
return function_(a1, static_cast<Argument<T, A1, 1>*>(record)->dTdA);
|
||||
return function_(a1, jacobian<A1, 1>(record));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -636,15 +646,11 @@ public:
|
|||
trace.setFunction(record);
|
||||
|
||||
raw = (char*) (record + 1);
|
||||
A1 a1 = this->expressionA1_->traceExecution(values,
|
||||
static_cast<Argument<T, A1, 1>*>(record)->trace, raw);
|
||||
|
||||
A1 a1 = expressionA1_->traceExecution(values, getTrace<A1, 1>(record), raw);
|
||||
raw = raw + expressionA1_->traceSize();
|
||||
A2 a2 = this->expressionA2_->traceExecution(values,
|
||||
static_cast<Argument<T, A2, 2>*>(record)->trace, raw);
|
||||
A2 a2 = expressionA2_->traceExecution(values, getTrace<A2, 2>(record), raw);
|
||||
|
||||
return function_(a1, a2, static_cast<Argument<T, A1, 1>*>(record)->dTdA,
|
||||
static_cast<Argument<T, A2, 2>*>(record)->dTdA);
|
||||
return function_(a1, a2, jacobian<A1, 1>(record), jacobian<A2, 2>(record));
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -736,20 +742,14 @@ public:
|
|||
trace.setFunction(record);
|
||||
|
||||
raw = (char*) (record + 1);
|
||||
A1 a1 = this->expressionA1_->traceExecution(values,
|
||||
static_cast<Argument<T, A1, 1>*>(record)->trace, raw);
|
||||
|
||||
A1 a1 = expressionA1_->traceExecution(values, getTrace<A1, 1>(record), raw);
|
||||
raw = raw + expressionA1_->traceSize();
|
||||
A2 a2 = this->expressionA2_->traceExecution(values,
|
||||
static_cast<Argument<T, A2, 2>*>(record)->trace, raw);
|
||||
|
||||
A2 a2 = expressionA2_->traceExecution(values, getTrace<A2, 2>(record), raw);
|
||||
raw = raw + expressionA2_->traceSize();
|
||||
A3 a3 = this->expressionA3_->traceExecution(values,
|
||||
static_cast<Argument<T, A3, 3>*>(record)->trace, raw);
|
||||
A3 a3 = expressionA3_->traceExecution(values, getTrace<A3, 3>(record), raw);
|
||||
|
||||
return function_(a1, a2, a3, static_cast<Argument<T, A1, 1>*>(record)->dTdA,
|
||||
static_cast<Argument<T, A2, 2>*>(record)->dTdA,
|
||||
static_cast<Argument<T, A3, 3>*>(record)->dTdA);
|
||||
return function_(a1, a2, a3, jacobian<A1, 1>(record),
|
||||
jacobian<A2, 2>(record), jacobian<A3, 3>(record));
|
||||
}
|
||||
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue