FunctionalNode inherited for all three functional ExpressionNode sub-classes
parent
8100d89094
commit
a9d9fcd241
|
@ -489,22 +489,16 @@ template<class T, class TYPES>
|
|||
struct Record: public boost::mpl::fold<TYPES, CallRecord<T::dimension>,
|
||||
GenerateRecord<T, MPL::_2, MPL::_1> >::type {
|
||||
|
||||
/// Access JacobianTrace
|
||||
template<class A, size_t N>
|
||||
JacobianTrace<typename Record::return_type, A, N>& jacobianTrace() {
|
||||
return static_cast<JacobianTrace<T, A, N>&>(*this);
|
||||
}
|
||||
|
||||
/// Access Trace
|
||||
template<class A, size_t N>
|
||||
ExecutionTrace<A>& trace() {
|
||||
return jacobianTrace<A, N>().trace;
|
||||
return static_cast<JacobianTrace<T, A, N>&>(*this).trace;
|
||||
}
|
||||
|
||||
/// Access Jacobian
|
||||
template<class A, size_t N>
|
||||
Eigen::Matrix<double, T::dimension, A::dimension>& jacobian() {
|
||||
return jacobianTrace<A, N>().dTdA;
|
||||
return static_cast<JacobianTrace<T, A, N>&>(*this).dTdA;
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -534,20 +528,21 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
|
|||
template<class T, class TYPES>
|
||||
struct FunctionalNode: public boost::mpl::fold<TYPES, ExpressionNode<T>,
|
||||
GenerateFunctionalNode<T, MPL::_2, MPL::_1> >::type {
|
||||
|
||||
/// Access Expression
|
||||
template<class A, size_t N>
|
||||
boost::shared_ptr<ExpressionNode<A> > expression() {
|
||||
return static_cast<Argument<T, A, N> &>(*this).expression;
|
||||
}
|
||||
|
||||
/// Access Expression, const version
|
||||
template<class A, size_t N>
|
||||
boost::shared_ptr<ExpressionNode<A> > expression() const {
|
||||
return static_cast<Argument<T, A, N> const &>(*this).expression;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/// Access Argument
|
||||
template<class A, size_t N, class Record>
|
||||
Argument<typename Record::return_type, A, N>& argument(Record& record) {
|
||||
return static_cast<Argument<typename Record::return_type, A, N>&>(record);
|
||||
}
|
||||
|
||||
/// Access Expression
|
||||
template<class A, size_t N, class Record>
|
||||
ExecutionTrace<A>& expression(Record* record) {
|
||||
return argument<A, N>(*record).expression;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/// Unary Function Expression
|
||||
|
@ -562,11 +557,11 @@ public:
|
|||
private:
|
||||
|
||||
Function function_;
|
||||
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
|
||||
|
||||
/// Constructor with a unary function f, and input argument e
|
||||
UnaryExpression(Function f, const Expression<A1>& e1) :
|
||||
function_(f), expressionA1_(e1.root()) {
|
||||
function_(f) {
|
||||
this->template expression<A1, 1>() = e1.root();
|
||||
ExpressionNode<T>::traceSize_ = sizeof(Record) + e1.traceSize();
|
||||
}
|
||||
|
||||
|
@ -576,18 +571,18 @@ public:
|
|||
|
||||
/// Return keys that play in this expression
|
||||
virtual std::set<Key> keys() const {
|
||||
return expressionA1_->keys();
|
||||
return this->template expression<A1, 1>()->keys();
|
||||
}
|
||||
|
||||
/// Return value
|
||||
virtual T value(const Values& values) const {
|
||||
return function_(this->expressionA1_->value(values), boost::none);
|
||||
return function_(this->template expression<A1, 1>()->value(values), boost::none);
|
||||
}
|
||||
|
||||
/// Return value and derivatives
|
||||
virtual Augmented<T> forward(const Values& values) const {
|
||||
using boost::none;
|
||||
Augmented<A1> argument = this->expressionA1_->forward(values);
|
||||
Augmented<A1> argument = this->template expression<A1, 1>()->forward(values);
|
||||
JacobianTA dTdA;
|
||||
T t = function_(argument.value(),
|
||||
argument.constant() ? none : boost::optional<JacobianTA&>(dTdA));
|
||||
|
@ -605,7 +600,7 @@ public:
|
|||
trace.setFunction(record);
|
||||
|
||||
raw = (char*) (record + 1);
|
||||
A1 a1 = expressionA1_->traceExecution(values,
|
||||
A1 a1 = this-> template expression<A1, 1>()->traceExecution(values,
|
||||
record->template trace<A1, 1>(), raw);
|
||||
|
||||
return function_(a1, record->template jacobian<A1, 1>());
|
||||
|
@ -616,7 +611,7 @@ public:
|
|||
/// Binary Expression
|
||||
|
||||
template<class T, class A1, class A2>
|
||||
class BinaryExpression: public ExpressionNode<T> {
|
||||
class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> > {
|
||||
|
||||
public:
|
||||
|
||||
|
@ -629,13 +624,13 @@ public:
|
|||
private:
|
||||
|
||||
Function function_;
|
||||
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
|
||||
boost::shared_ptr<ExpressionNode<A2> > expressionA2_;
|
||||
|
||||
/// Constructor with a binary function f, and two input arguments
|
||||
BinaryExpression(Function f, //
|
||||
const Expression<A1>& e1, const Expression<A2>& e2) :
|
||||
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()) {
|
||||
/// Constructor with a ternary function f, and three input arguments
|
||||
BinaryExpression(Function f, const Expression<A1>& e1,
|
||||
const Expression<A2>& e2) :
|
||||
function_(f) {
|
||||
this->template expression<A1, 1>() = e1.root();
|
||||
this->template expression<A2, 2>() = e2.root();
|
||||
ExpressionNode<T>::traceSize_ = //
|
||||
sizeof(Record) + e1.traceSize() + e2.traceSize();
|
||||
}
|
||||
|
@ -647,8 +642,8 @@ public:
|
|||
|
||||
/// Return keys that play in this expression
|
||||
virtual std::set<Key> keys() const {
|
||||
std::set<Key> keys1 = expressionA1_->keys();
|
||||
std::set<Key> keys2 = expressionA2_->keys();
|
||||
std::set<Key> keys1 = this->template expression<A1, 1>()->keys();
|
||||
std::set<Key> keys2 = this->template expression<A2, 2>()->keys();
|
||||
keys1.insert(keys2.begin(), keys2.end());
|
||||
return keys1;
|
||||
}
|
||||
|
@ -656,15 +651,16 @@ public:
|
|||
/// Return value
|
||||
virtual T value(const Values& values) const {
|
||||
using boost::none;
|
||||
return function_(this->expressionA1_->value(values),
|
||||
this->expressionA2_->value(values), none, none);
|
||||
return function_(this->template expression<A1, 1>()->value(values),
|
||||
this->template expression<A2, 2>()->value(values),
|
||||
none, none);
|
||||
}
|
||||
|
||||
/// Return value and derivatives
|
||||
virtual Augmented<T> forward(const Values& values) const {
|
||||
using boost::none;
|
||||
Augmented<A1> a1 = this->expressionA1_->forward(values);
|
||||
Augmented<A2> a2 = this->expressionA2_->forward(values);
|
||||
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
|
||||
Augmented<A2> a2 = this->template expression<A2, 2>()->forward(values);
|
||||
JacobianTA1 dTdA1;
|
||||
JacobianTA2 dTdA2;
|
||||
T t = function_(a1.value(), a2.value(),
|
||||
|
@ -678,30 +674,29 @@ public:
|
|||
typedef Record<T, Arguments> Record;
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
/// The raw buffer is [Record | A1 raw | A2 raw]
|
||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||
char* raw) const {
|
||||
Record* record = new (raw) Record();
|
||||
trace.setFunction(record);
|
||||
|
||||
raw = (char*) (record + 1);
|
||||
A1 a1 = expressionA1_->traceExecution(values,
|
||||
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
||||
record->template trace<A1, 1>(), raw);
|
||||
raw = raw + expressionA1_->traceSize();
|
||||
A2 a2 = expressionA2_->traceExecution(values,
|
||||
raw = raw + this->template expression<A1, 1>()->traceSize();
|
||||
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
||||
record->template trace<A2, 2>(), raw);
|
||||
raw = raw + this->template expression<A2, 2>()->traceSize();
|
||||
|
||||
return function_(a1, a2, record->template jacobian<A1, 1>(),
|
||||
record->template jacobian<A2, 2>());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
/// Ternary Expression
|
||||
|
||||
template<class T, class A1, class A2, class A3>
|
||||
class TernaryExpression: public ExpressionNode<T> {
|
||||
class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> > {
|
||||
|
||||
public:
|
||||
|
||||
|
@ -715,15 +710,14 @@ public:
|
|||
private:
|
||||
|
||||
Function function_;
|
||||
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
|
||||
boost::shared_ptr<ExpressionNode<A2> > expressionA2_;
|
||||
boost::shared_ptr<ExpressionNode<A3> > expressionA3_;
|
||||
|
||||
/// Constructor with a ternary function f, and three input arguments
|
||||
TernaryExpression(Function f, const Expression<A1>& e1,
|
||||
const Expression<A2>& e2, const Expression<A3>& e3) :
|
||||
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_(
|
||||
e3.root()) {
|
||||
function_(f) {
|
||||
this->template expression<A1, 1>() = e1.root();
|
||||
this->template expression<A2, 2>() = e2.root();
|
||||
this->template expression<A3, 3>() = e3.root();
|
||||
ExpressionNode<T>::traceSize_ = //
|
||||
sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize();
|
||||
}
|
||||
|
@ -734,9 +728,9 @@ public:
|
|||
|
||||
/// Return keys that play in this expression
|
||||
virtual std::set<Key> keys() const {
|
||||
std::set<Key> keys1 = expressionA1_->keys();
|
||||
std::set<Key> keys2 = expressionA2_->keys();
|
||||
std::set<Key> keys3 = expressionA3_->keys();
|
||||
std::set<Key> keys1 = this->template expression<A1, 1>()->keys();
|
||||
std::set<Key> keys2 = this->template expression<A2, 2>()->keys();
|
||||
std::set<Key> keys3 = this->template expression<A3, 3>()->keys();
|
||||
keys2.insert(keys3.begin(), keys3.end());
|
||||
keys1.insert(keys2.begin(), keys2.end());
|
||||
return keys1;
|
||||
|
@ -745,17 +739,18 @@ public:
|
|||
/// Return value
|
||||
virtual T value(const Values& values) const {
|
||||
using boost::none;
|
||||
return function_(this->expressionA1_->value(values),
|
||||
this->expressionA2_->value(values), this->expressionA3_->value(values),
|
||||
return function_(this->template expression<A1, 1>()->value(values),
|
||||
this->template expression<A2, 2>()->value(values),
|
||||
this->template expression<A3, 3>()->value(values),
|
||||
none, none, none);
|
||||
}
|
||||
|
||||
/// Return value and derivatives
|
||||
virtual Augmented<T> forward(const Values& values) const {
|
||||
using boost::none;
|
||||
Augmented<A1> a1 = this->expressionA1_->forward(values);
|
||||
Augmented<A2> a2 = this->expressionA2_->forward(values);
|
||||
Augmented<A3> a3 = this->expressionA3_->forward(values);
|
||||
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
|
||||
Augmented<A2> a2 = this->template expression<A2, 2>()->forward(values);
|
||||
Augmented<A3> a3 = this->template expression<A3, 3>()->forward(values);
|
||||
JacobianTA1 dTdA1;
|
||||
JacobianTA2 dTdA2;
|
||||
JacobianTA3 dTdA3;
|
||||
|
@ -778,13 +773,13 @@ public:
|
|||
trace.setFunction(record);
|
||||
|
||||
raw = (char*) (record + 1);
|
||||
A1 a1 = expressionA1_->traceExecution(values,
|
||||
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
||||
record->template trace<A1, 1>(), raw);
|
||||
raw = raw + expressionA1_->traceSize();
|
||||
A2 a2 = expressionA2_->traceExecution(values,
|
||||
raw = raw + this->template expression<A1, 1>()->traceSize();
|
||||
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
||||
record->template trace<A2, 2>(), raw);
|
||||
raw = raw + expressionA2_->traceSize();
|
||||
A3 a3 = expressionA3_->traceExecution(values,
|
||||
raw = raw + this->template expression<A2, 2>()->traceSize();
|
||||
A3 a3 = this->template expression<A3, 3>()->traceExecution(values,
|
||||
record->template trace<A3, 3>(), raw);
|
||||
|
||||
return function_(a1, a2, a3, record->template jacobian<A1, 1>(),
|
||||
|
|
Loading…
Reference in New Issue