Refactoring for readability/sanity

release/4.3a0
dellaert 2014-10-11 21:33:07 +02:00
parent 1bac83381f
commit c776e87f78
1 changed files with 36 additions and 36 deletions

View File

@ -198,36 +198,29 @@ struct Argument {
template<class T, class AN, class More>
struct Record: Argument<T, typename AN::type, AN::value>, More {
typedef T return_type;
typedef typename AN::type A;
const static size_t N = AN::value;
ExecutionTrace<A> const & myTrace() const {
return static_cast<const Argument<T, A, AN::value>*>(this)->trace;
}
typedef Eigen::Matrix<double, T::dimension, A::dimension> JacobianTA;
const JacobianTA& myJacobian() const {
return static_cast<const Argument<T, A, AN::value>*>(this)->dTdA;
}
typedef Argument<T, A, N> This;
/// Print to std::cout
virtual void print(const std::string& indent) const {
More::print(indent);
static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]");
std::cout << myJacobian().format(matlab) << std::endl;
myTrace().print(indent);
std::cout << This::dTdA.format(matlab) << std::endl;
This::trace.print(indent);
}
/// Start the reverse AD process
virtual void startReverseAD(JacobianMap& jacobians) const {
More::startReverseAD(jacobians);
Select<T::dimension, A>::reverseAD(myTrace(), myJacobian(), jacobians);
Select<T::dimension, A>::reverseAD(This::trace, This::dTdA, jacobians);
}
/// Given df/dT, multiply in dT/dA and continue reverse AD process
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
More::reverseAD(dFdT, jacobians);
myTrace().reverseAD(dFdT * myJacobian(), jacobians);
This::trace.reverseAD(dFdT * This::dTdA, jacobians);
}
/// Version specialized to 2-dimensional output
@ -235,7 +228,7 @@ struct Record: Argument<T, typename AN::type, AN::value>, More {
virtual void reverseAD2(const Jacobian2T& dFdT,
JacobianMap& jacobians) const {
More::reverseAD2(dFdT, jacobians);
myTrace().reverseAD2(dFdT * myJacobian(), jacobians);
This::trace.reverseAD2(dFdT * This::dTdA, jacobians);
}
};
@ -252,9 +245,27 @@ template<class T, class TYPES>
struct GenerateRecord {
typedef typename boost::mpl::fold<TYPES, CallRecord<T::dimension>,
Record<T, MPL::_2, MPL::_1> >::type type;
};
/// Access Argument
template<class A, size_t N, class Record>
Argument<typename Record::return_type, A, N>& argument(Record& record) {
return static_cast<Argument<typename Record::return_type, A, N>&>(record);
}
/// Access Trace
template<class A, size_t N, class Record>
ExecutionTrace<A>& getTrace(Record* record) {
return argument<A, N>(*record).trace;
}
/// Access Jacobian
template<class A, size_t N, class Record>
Eigen::Matrix<double, Record::return_type::dimension, A::dimension>& jacobian(
Record* record) {
return argument<A, N>(*record).dTdA;
}
//-----------------------------------------------------------------------------
/**
* Value and Jacobians
@ -552,10 +563,9 @@ public:
trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = this->expressionA1_->traceExecution(values,
static_cast<Argument<T, A1, 1>*>(record)->trace, raw);
A1 a1 = expressionA1_->traceExecution(values, getTrace<A1, 1>(record), raw);
return function_(a1, static_cast<Argument<T, A1, 1>*>(record)->dTdA);
return function_(a1, jacobian<A1, 1>(record));
}
};
@ -636,15 +646,11 @@ public:
trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = this->expressionA1_->traceExecution(values,
static_cast<Argument<T, A1, 1>*>(record)->trace, raw);
A1 a1 = expressionA1_->traceExecution(values, getTrace<A1, 1>(record), raw);
raw = raw + expressionA1_->traceSize();
A2 a2 = this->expressionA2_->traceExecution(values,
static_cast<Argument<T, A2, 2>*>(record)->trace, raw);
A2 a2 = expressionA2_->traceExecution(values, getTrace<A2, 2>(record), raw);
return function_(a1, a2, static_cast<Argument<T, A1, 1>*>(record)->dTdA,
static_cast<Argument<T, A2, 2>*>(record)->dTdA);
return function_(a1, a2, jacobian<A1, 1>(record), jacobian<A2, 2>(record));
}
};
@ -736,20 +742,14 @@ public:
trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = this->expressionA1_->traceExecution(values,
static_cast<Argument<T, A1, 1>*>(record)->trace, raw);
A1 a1 = expressionA1_->traceExecution(values, getTrace<A1, 1>(record), raw);
raw = raw + expressionA1_->traceSize();
A2 a2 = this->expressionA2_->traceExecution(values,
static_cast<Argument<T, A2, 2>*>(record)->trace, raw);
A2 a2 = expressionA2_->traceExecution(values, getTrace<A2, 2>(record), raw);
raw = raw + expressionA2_->traceSize();
A3 a3 = this->expressionA3_->traceExecution(values,
static_cast<Argument<T, A3, 3>*>(record)->trace, raw);
A3 a3 = expressionA3_->traceExecution(values, getTrace<A3, 3>(record), raw);
return function_(a1, a2, a3, static_cast<Argument<T, A1, 1>*>(record)->dTdA,
static_cast<Argument<T, A2, 2>*>(record)->dTdA,
static_cast<Argument<T, A3, 3>*>(record)->dTdA);
return function_(a1, a2, a3, jacobian<A1, 1>(record),
jacobian<A2, 2>(record), jacobian<A3, 3>(record));
}
};