Removed MPL complexity from UnaryExpression.
parent
b52ced7a09
commit
660acec58e
|
|
@ -629,24 +629,25 @@ struct FunctionalNode {
|
||||||
|
|
||||||
/// Unary Function Expression
|
/// Unary Function Expression
|
||||||
template<class T, class A1>
|
template<class T, class A1>
|
||||||
class UnaryExpression: public FunctionalNode<T, boost::mpl::vector<A1> >::type {
|
class UnaryExpression : public ExpressionNode<T> {
|
||||||
|
|
||||||
typedef typename MakeOptionalJacobian<T, A1>::type OJ1;
|
typedef typename MakeOptionalJacobian<T, A1>::type OJ1;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef boost::function<T(const A1&, OJ1)> Function;
|
typedef boost::function<T(const A1&, OJ1)> Function;
|
||||||
typedef typename FunctionalNode<T, boost::mpl::vector<A1> >::type Base;
|
|
||||||
typedef typename Base::Record Record;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
Function function_;
|
Function function_;
|
||||||
|
boost::shared_ptr<ExpressionNode<A1> > expression1_;
|
||||||
|
|
||||||
|
typedef Argument<T, A1, 1> This; ///< The storage we have direct access to
|
||||||
|
|
||||||
/// Constructor with a unary function f, and input argument e
|
/// Constructor with a unary function f, and input argument e
|
||||||
UnaryExpression(Function f, const Expression<A1>& e1) :
|
UnaryExpression(Function f, const Expression<A1>& e1) :
|
||||||
function_(f) {
|
function_(f) {
|
||||||
this->template reset<A1, 1>(e1.root());
|
this->expression1_ = e1.root();
|
||||||
ExpressionNode<T>::traceSize_ = upAligned(sizeof(Record)) + e1.traceSize();
|
ExpressionNode<T>::traceSize_ = upAligned(sizeof(Record)) + e1.traceSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -656,14 +657,96 @@ public:
|
||||||
|
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const {
|
virtual T value(const Values& values) const {
|
||||||
return function_(this->template expression<A1, 1>()->value(values), boost::none);
|
return function_(this->expression1_->value(values), boost::none);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inner Record Class
|
||||||
|
// The reason we inherit from JacobianTrace<T, A, N> is because we can then
|
||||||
|
// case to this unique signature to retrieve the value/trace at any level
|
||||||
|
struct Record: public internal::CallRecordImplementor<Record,
|
||||||
|
traits<T>::dimension>, JacobianTrace<T, A1, 1> {
|
||||||
|
|
||||||
|
typedef T return_type;
|
||||||
|
typedef JacobianTrace<T, A1, 1> This;
|
||||||
|
|
||||||
|
/// Access Jacobian
|
||||||
|
template<class A, size_t N>
|
||||||
|
typename Jacobian<T, A1>::type& jacobian() {
|
||||||
|
return static_cast<JacobianTrace<T, A, N>&>(*this).dTdA;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Access Value
|
||||||
|
template<class A, size_t N>
|
||||||
|
const A& value() const {
|
||||||
|
return static_cast<JacobianTrace<T, A, N> const &>(*this).value;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print to std::cout
|
||||||
|
void print(const std::string& indent) const {
|
||||||
|
static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]");
|
||||||
|
std::cout << This::dTdA.format(matlab) << std::endl;
|
||||||
|
This::trace.print(indent);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start the reverse AD process
|
||||||
|
void startReverseAD4(JacobianMap& jacobians) const {
|
||||||
|
// This is the crucial point where the size of the AD pipeline is selected.
|
||||||
|
// One pipeline is started for each argument, but the number of rows in each
|
||||||
|
// pipeline is the same, namely the dimension of the output argument T.
|
||||||
|
// For example, if the entire expression is rooted by a binary function
|
||||||
|
// yielding a 2D result, then the matrix dTdA will have 2 rows.
|
||||||
|
// ExecutionTrace::reverseAD1 just passes this on to CallRecord::reverseAD2
|
||||||
|
// which calls the correctly sized CallRecord::reverseAD3, which in turn
|
||||||
|
// calls reverseAD4 below.
|
||||||
|
This::trace.reverseAD1(This::dTdA, jacobians);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
||||||
|
// Cols is always known at compile time
|
||||||
|
template<typename SomeMatrix>
|
||||||
|
void reverseAD4(const SomeMatrix & dFdT,
|
||||||
|
JacobianMap& jacobians) const {
|
||||||
|
This::trace.reverseAD1(dFdT * This::dTdA, jacobians);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Construct an execution trace for reverse AD
|
||||||
|
void trace(const Values& values, Record* record,
|
||||||
|
ExecutionTraceStorage*& traceStorage) const {
|
||||||
|
// Write an Expression<A> execution trace in record->trace
|
||||||
|
// Iff Constant or Leaf, this will not write to traceStorage, only to trace.
|
||||||
|
// Iff the expression is functional, write all Records in traceStorage buffer
|
||||||
|
// Return value of type T is recorded in record->value
|
||||||
|
record->Record::This::value = this->expression1_->traceExecution(values,
|
||||||
|
record->Record::This::trace, traceStorage);
|
||||||
|
// traceStorage is never modified by traceExecution, but if traceExecution has
|
||||||
|
// written in the buffer, the next caller expects we advance the pointer
|
||||||
|
traceStorage += this->expression1_->traceSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct an execution trace for reverse AD
|
||||||
|
Record* trace(const Values& values,
|
||||||
|
ExecutionTraceStorage* traceStorage) const {
|
||||||
|
assert(reinterpret_cast<size_t>(traceStorage) % TraceAlignment == 0);
|
||||||
|
|
||||||
|
// Create the record and advance the pointer
|
||||||
|
Record* record = new (traceStorage) Record();
|
||||||
|
traceStorage += upAligned(sizeof(Record));
|
||||||
|
|
||||||
|
// Record the traces for all arguments
|
||||||
|
// After this, the traceStorage pointer is set to after what was written
|
||||||
|
this->trace(values, record, traceStorage);
|
||||||
|
|
||||||
|
// Return the record for this function evaluation
|
||||||
|
return record;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
ExecutionTraceStorage* traceStorage) const {
|
ExecutionTraceStorage* traceStorage) const {
|
||||||
|
|
||||||
Record* record = Base::trace(values, traceStorage);
|
Record* record = this->trace(values, traceStorage);
|
||||||
|
record->print("record: ");
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
||||||
return function_(record->template value<A1, 1>(),
|
return function_(record->template value<A1, 1>(),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue