Pre-big collapse: prototype recursively defined inner Record2 type
parent
7fde47c48b
commit
bc9e11f43c
|
@ -307,10 +307,6 @@ struct Select<2, A> {
|
||||||
template<class T>
|
template<class T>
|
||||||
class ExpressionNode {
|
class ExpressionNode {
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
static size_t const N = 0; // number of arguments
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
||||||
size_t traceSize_;
|
size_t traceSize_;
|
||||||
|
@ -510,7 +506,7 @@ struct Record: public boost::mpl::fold<TYPES, CallRecord<T::dimension>,
|
||||||
// to recursively generate a class, that will be the base for function nodes.
|
// to recursively generate a class, that will be the base for function nodes.
|
||||||
// The class generated, for two arguments A1, A2, and A3 will be
|
// The class generated, for two arguments A1, A2, and A3 will be
|
||||||
//
|
//
|
||||||
// struct Base1 : Argument<T,A1,1>, ExpressionNode<T> {
|
// struct Base1 : Argument<T,A1,1>, FunctionalBase<T> {
|
||||||
// ... storage related to A1 ...
|
// ... storage related to A1 ...
|
||||||
// ... methods that work on A1 ...
|
// ... methods that work on A1 ...
|
||||||
// };
|
// };
|
||||||
|
@ -535,7 +531,21 @@ struct Record: public boost::mpl::fold<TYPES, CallRecord<T::dimension>,
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Building block for Recursive FunctionalNode Class
|
* Base case for recursive FunctionalNode class
|
||||||
|
*/
|
||||||
|
template<class T>
|
||||||
|
struct FunctionalBase: ExpressionNode<T> {
|
||||||
|
static size_t const N = 0; // number of arguments
|
||||||
|
|
||||||
|
typedef CallRecord<T::dimension> Record2;
|
||||||
|
|
||||||
|
/// Construct an execution trace for reverse AD
|
||||||
|
void trace(const Values& values, Record2* record, char*& raw) const {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Building block for recursive FunctionalNode class
|
||||||
* The integer argument N is to guarantee a unique type signature,
|
* The integer argument N is to guarantee a unique type signature,
|
||||||
* so we are guaranteed to be able to extract their values by static cast.
|
* so we are guaranteed to be able to extract their values by static cast.
|
||||||
*/
|
*/
|
||||||
|
@ -562,14 +572,60 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
|
||||||
return keys;
|
return keys;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Recursive Record Class for Functional Expressions
|
||||||
|
*/
|
||||||
|
struct Record2: JacobianTrace<T, A, N>, Base::Record2 {
|
||||||
|
|
||||||
|
typedef T return_type;
|
||||||
|
typedef JacobianTrace<T, A, N> This;
|
||||||
|
|
||||||
|
/// Print to std::cout
|
||||||
|
virtual void print(const std::string& indent) const {
|
||||||
|
Base::Record2::print(indent);
|
||||||
|
static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]");
|
||||||
|
std::cout << This::dTdA.format(matlab) << std::endl;
|
||||||
|
This::trace.print(indent);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start the reverse AD process
|
||||||
|
virtual void startReverseAD(JacobianMap& jacobians) const {
|
||||||
|
Base::Record2::startReverseAD(jacobians);
|
||||||
|
Select<T::dimension, A>::reverseAD(This::trace, This::dTdA, jacobians);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
||||||
|
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
|
||||||
|
Base::Record2::reverseAD(dFdT, jacobians);
|
||||||
|
This::trace.reverseAD(dFdT * This::dTdA, jacobians);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Version specialized to 2-dimensional output
|
||||||
|
typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T;
|
||||||
|
virtual void reverseAD2(const Jacobian2T& dFdT,
|
||||||
|
JacobianMap& jacobians) const {
|
||||||
|
Base::Record2::reverseAD2(dFdT, jacobians);
|
||||||
|
This::trace.reverseAD2(dFdT * This::dTdA, jacobians);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Construct an execution trace for reverse AD
|
||||||
|
void trace(const Values& values, Record2* record, char*& raw) const {
|
||||||
|
Base::trace(values, record, raw);
|
||||||
|
A a = This::expression->traceExecution(values, record->Record2::This::trace, raw);
|
||||||
|
raw = raw + This::expression->traceSize();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Recursive GenerateFunctionalNode class Generator
|
* Recursive GenerateFunctionalNode class Generator
|
||||||
*/
|
*/
|
||||||
template<class T, class TYPES>
|
template<class T, class TYPES>
|
||||||
struct FunctionalNode: public boost::mpl::fold<TYPES, ExpressionNode<T>,
|
struct FunctionalNode {
|
||||||
GenerateFunctionalNode<T, MPL::_2, MPL::_1> >::type {
|
typedef typename boost::mpl::fold<TYPES, FunctionalBase<T>,
|
||||||
|
GenerateFunctionalNode<T, MPL::_2, MPL::_1> >::type Base;
|
||||||
|
|
||||||
|
struct type: public Base {
|
||||||
|
|
||||||
/// Reset expression shared pointer
|
/// Reset expression shared pointer
|
||||||
template<class A, size_t N>
|
template<class A, size_t N>
|
||||||
|
@ -583,13 +639,24 @@ struct FunctionalNode: public boost::mpl::fold<TYPES, ExpressionNode<T>,
|
||||||
return static_cast<Argument<T, A, N> const &>(*this).expression;
|
return static_cast<Argument<T, A, N> const &>(*this).expression;
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
/// Construct an execution trace for reverse AD
|
||||||
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
|
char* raw) const {
|
||||||
|
typename Base::Record2* record = new (raw) typename Base::Record2();
|
||||||
|
trace.setFunction(record);
|
||||||
|
raw = (char*) (record + 1);
|
||||||
|
|
||||||
|
this->trace(values, record, raw);
|
||||||
|
|
||||||
|
return T(); // TODO
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
|
|
||||||
/// Unary Function Expression
|
/// Unary Function Expression
|
||||||
template<class T, class A1>
|
template<class T, class A1>
|
||||||
class UnaryExpression: public FunctionalNode<T, boost::mpl::vector<A1> > {
|
class UnaryExpression: public FunctionalNode<T, boost::mpl::vector<A1> >::type {
|
||||||
|
|
||||||
/// The automatically generated Base class
|
/// The automatically generated Base class
|
||||||
typedef FunctionalNode<T, boost::mpl::vector<A1> > Base;
|
typedef FunctionalNode<T, boost::mpl::vector<A1> > Base;
|
||||||
|
@ -636,10 +703,11 @@ public:
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
Record* record = new (raw) Record();
|
Record* record = new (raw) Record();
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
||||||
raw = (char*) (record + 1);
|
raw = (char*) (record + 1);
|
||||||
A1 a1 = this-> template expression<A1, 1>()->traceExecution(values,
|
|
||||||
|
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
||||||
record->template trace<A1, 1>(), raw);
|
record->template trace<A1, 1>(), raw);
|
||||||
|
raw = raw + this->template expression<A1, 1>()->traceSize();
|
||||||
|
|
||||||
return function_(a1, record->template jacobian<A1, 1>());
|
return function_(a1, record->template jacobian<A1, 1>());
|
||||||
}
|
}
|
||||||
|
@ -649,7 +717,7 @@ public:
|
||||||
/// Binary Expression
|
/// Binary Expression
|
||||||
|
|
||||||
template<class T, class A1, class A2>
|
template<class T, class A1, class A2>
|
||||||
class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> > {
|
class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> >::type {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
@ -706,13 +774,15 @@ public:
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
Record* record = new (raw) Record();
|
Record* record = new (raw) Record();
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
||||||
raw = (char*) (record + 1);
|
raw = (char*) (record + 1);
|
||||||
|
|
||||||
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
||||||
record->template trace<A1, 1>(), raw);
|
record->template trace<A1, 1>(), raw);
|
||||||
raw = raw + this->template expression<A1, 1>()->traceSize();
|
raw = raw + this->template expression<A1, 1>()->traceSize();
|
||||||
|
|
||||||
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
||||||
record->template trace<A2, 2>(), raw);
|
record->template trace<A2, 2>(), raw);
|
||||||
|
raw = raw + this->template expression<A2, 2>()->traceSize();
|
||||||
|
|
||||||
return function_(a1, a2, record->template jacobian<A1, 1>(),
|
return function_(a1, a2, record->template jacobian<A1, 1>(),
|
||||||
record->template jacobian<A2, 2>());
|
record->template jacobian<A2, 2>());
|
||||||
|
@ -723,7 +793,7 @@ public:
|
||||||
/// Ternary Expression
|
/// Ternary Expression
|
||||||
|
|
||||||
template<class T, class A1, class A2, class A3>
|
template<class T, class A1, class A2, class A3>
|
||||||
class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> > {
|
class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
@ -786,16 +856,19 @@ public:
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
Record* record = new (raw) Record();
|
Record* record = new (raw) Record();
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
||||||
raw = (char*) (record + 1);
|
raw = (char*) (record + 1);
|
||||||
|
|
||||||
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
||||||
record->template trace<A1, 1>(), raw);
|
record->template trace<A1, 1>(), raw);
|
||||||
raw = raw + this->template expression<A1, 1>()->traceSize();
|
raw = raw + this->template expression<A1, 1>()->traceSize();
|
||||||
|
|
||||||
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
||||||
record->template trace<A2, 2>(), raw);
|
record->template trace<A2, 2>(), raw);
|
||||||
raw = raw + this->template expression<A2, 2>()->traceSize();
|
raw = raw + this->template expression<A2, 2>()->traceSize();
|
||||||
|
|
||||||
A3 a3 = this->template expression<A3, 3>()->traceExecution(values,
|
A3 a3 = this->template expression<A3, 3>()->traceExecution(values,
|
||||||
record->template trace<A3, 3>(), raw);
|
record->template trace<A3, 3>(), raw);
|
||||||
|
raw = raw + this->template expression<A3, 3>()->traceSize();
|
||||||
|
|
||||||
return function_(a1, a2, a3, record->template jacobian<A1, 1>(),
|
return function_(a1, a2, a3, record->template jacobian<A1, 1>(),
|
||||||
record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>());
|
record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>());
|
||||||
|
|
Loading…
Reference in New Issue