Removed MPL complexity from UnaryExpression.

release/4.3a0
dellaert 2015-05-03 20:40:13 -07:00
parent b52ced7a09
commit 660acec58e
1 changed files with 89 additions and 6 deletions

View File

@ -629,24 +629,25 @@ struct FunctionalNode {
/// Unary Function Expression
template<class T, class A1>
class UnaryExpression: public FunctionalNode<T, boost::mpl::vector<A1> >::type {
class UnaryExpression : public ExpressionNode<T> {
typedef typename MakeOptionalJacobian<T, A1>::type OJ1;
public:
typedef boost::function<T(const A1&, OJ1)> Function;
typedef typename FunctionalNode<T, boost::mpl::vector<A1> >::type Base;
typedef typename Base::Record Record;
private:
Function function_;
boost::shared_ptr<ExpressionNode<A1> > expression1_;
typedef Argument<T, A1, 1> This; ///< The storage we have direct access to
/// Constructor with a unary function f, and input argument e
UnaryExpression(Function f, const Expression<A1>& e1) :
function_(f) {
this->template reset<A1, 1>(e1.root());
this->expression1_ = e1.root();
ExpressionNode<T>::traceSize_ = upAligned(sizeof(Record)) + e1.traceSize();
}
@ -656,14 +657,96 @@ public:
/// Return value
virtual T value(const Values& values) const {
return function_(this->template expression<A1, 1>()->value(values), boost::none);
return function_(this->expression1_->value(values), boost::none);
}
// Inner Record Class
// The reason we inherit from JacobianTrace<T, A, N> is because we can then
// case to this unique signature to retrieve the value/trace at any level
struct Record: public internal::CallRecordImplementor<Record,
traits<T>::dimension>, JacobianTrace<T, A1, 1> {
typedef T return_type;
typedef JacobianTrace<T, A1, 1> This;
/// Access Jacobian
template<class A, size_t N>
typename Jacobian<T, A1>::type& jacobian() {
return static_cast<JacobianTrace<T, A, N>&>(*this).dTdA;
}
/// Access Value
template<class A, size_t N>
const A& value() const {
return static_cast<JacobianTrace<T, A, N> const &>(*this).value;
}
/// Print to std::cout
void print(const std::string& indent) const {
static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]");
std::cout << This::dTdA.format(matlab) << std::endl;
This::trace.print(indent);
}
/// Start the reverse AD process
void startReverseAD4(JacobianMap& jacobians) const {
// This is the crucial point where the size of the AD pipeline is selected.
// One pipeline is started for each argument, but the number of rows in each
// pipeline is the same, namely the dimension of the output argument T.
// For example, if the entire expression is rooted by a binary function
// yielding a 2D result, then the matrix dTdA will have 2 rows.
// ExecutionTrace::reverseAD1 just passes this on to CallRecord::reverseAD2
// which calls the correctly sized CallRecord::reverseAD3, which in turn
// calls reverseAD4 below.
This::trace.reverseAD1(This::dTdA, jacobians);
}
/// Given df/dT, multiply in dT/dA and continue reverse AD process
// Cols is always known at compile time
template<typename SomeMatrix>
void reverseAD4(const SomeMatrix & dFdT,
JacobianMap& jacobians) const {
This::trace.reverseAD1(dFdT * This::dTdA, jacobians);
}
};
/// Construct an execution trace for reverse AD
void trace(const Values& values, Record* record,
ExecutionTraceStorage*& traceStorage) const {
// Write an Expression<A> execution trace in record->trace
// Iff Constant or Leaf, this will not write to traceStorage, only to trace.
// Iff the expression is functional, write all Records in traceStorage buffer
// Return value of type T is recorded in record->value
record->Record::This::value = this->expression1_->traceExecution(values,
record->Record::This::trace, traceStorage);
// traceStorage is never modified by traceExecution, but if traceExecution has
// written in the buffer, the next caller expects we advance the pointer
traceStorage += this->expression1_->traceSize();
}
/// Construct an execution trace for reverse AD
Record* trace(const Values& values,
ExecutionTraceStorage* traceStorage) const {
assert(reinterpret_cast<size_t>(traceStorage) % TraceAlignment == 0);
// Create the record and advance the pointer
Record* record = new (traceStorage) Record();
traceStorage += upAligned(sizeof(Record));
// Record the traces for all arguments
// After this, the traceStorage pointer is set to after what was written
this->trace(values, record, traceStorage);
// Return the record for this function evaluation
return record;
}
/// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
ExecutionTraceStorage* traceStorage) const {
Record* record = Base::trace(values, traceStorage);
Record* record = this->trace(values, traceStorage);
record->print("record: ");
trace.setFunction(record);
return function_(record->template value<A1, 1>(),