FunctionalNode inherited for all three functional ExpressionNode sub-classes
parent
8100d89094
commit
a9d9fcd241
|
@ -489,22 +489,16 @@ template<class T, class TYPES>
|
||||||
struct Record: public boost::mpl::fold<TYPES, CallRecord<T::dimension>,
|
struct Record: public boost::mpl::fold<TYPES, CallRecord<T::dimension>,
|
||||||
GenerateRecord<T, MPL::_2, MPL::_1> >::type {
|
GenerateRecord<T, MPL::_2, MPL::_1> >::type {
|
||||||
|
|
||||||
/// Access JacobianTrace
|
|
||||||
template<class A, size_t N>
|
|
||||||
JacobianTrace<typename Record::return_type, A, N>& jacobianTrace() {
|
|
||||||
return static_cast<JacobianTrace<T, A, N>&>(*this);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Access Trace
|
/// Access Trace
|
||||||
template<class A, size_t N>
|
template<class A, size_t N>
|
||||||
ExecutionTrace<A>& trace() {
|
ExecutionTrace<A>& trace() {
|
||||||
return jacobianTrace<A, N>().trace;
|
return static_cast<JacobianTrace<T, A, N>&>(*this).trace;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Access Jacobian
|
/// Access Jacobian
|
||||||
template<class A, size_t N>
|
template<class A, size_t N>
|
||||||
Eigen::Matrix<double, T::dimension, A::dimension>& jacobian() {
|
Eigen::Matrix<double, T::dimension, A::dimension>& jacobian() {
|
||||||
return jacobianTrace<A, N>().dTdA;
|
return static_cast<JacobianTrace<T, A, N>&>(*this).dTdA;
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
@ -534,20 +528,21 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
|
||||||
template<class T, class TYPES>
|
template<class T, class TYPES>
|
||||||
struct FunctionalNode: public boost::mpl::fold<TYPES, ExpressionNode<T>,
|
struct FunctionalNode: public boost::mpl::fold<TYPES, ExpressionNode<T>,
|
||||||
GenerateFunctionalNode<T, MPL::_2, MPL::_1> >::type {
|
GenerateFunctionalNode<T, MPL::_2, MPL::_1> >::type {
|
||||||
};
|
|
||||||
|
|
||||||
/// Access Argument
|
|
||||||
template<class A, size_t N, class Record>
|
|
||||||
Argument<typename Record::return_type, A, N>& argument(Record& record) {
|
|
||||||
return static_cast<Argument<typename Record::return_type, A, N>&>(record);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Access Expression
|
/// Access Expression
|
||||||
template<class A, size_t N, class Record>
|
template<class A, size_t N>
|
||||||
ExecutionTrace<A>& expression(Record* record) {
|
boost::shared_ptr<ExpressionNode<A> > expression() {
|
||||||
return argument<A, N>(*record).expression;
|
return static_cast<Argument<T, A, N> &>(*this).expression;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Access Expression, const version
|
||||||
|
template<class A, size_t N>
|
||||||
|
boost::shared_ptr<ExpressionNode<A> > expression() const {
|
||||||
|
return static_cast<Argument<T, A, N> const &>(*this).expression;
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
|
|
||||||
/// Unary Function Expression
|
/// Unary Function Expression
|
||||||
|
@ -562,11 +557,11 @@ public:
|
||||||
private:
|
private:
|
||||||
|
|
||||||
Function function_;
|
Function function_;
|
||||||
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
|
|
||||||
|
|
||||||
/// Constructor with a unary function f, and input argument e
|
/// Constructor with a unary function f, and input argument e
|
||||||
UnaryExpression(Function f, const Expression<A1>& e1) :
|
UnaryExpression(Function f, const Expression<A1>& e1) :
|
||||||
function_(f), expressionA1_(e1.root()) {
|
function_(f) {
|
||||||
|
this->template expression<A1, 1>() = e1.root();
|
||||||
ExpressionNode<T>::traceSize_ = sizeof(Record) + e1.traceSize();
|
ExpressionNode<T>::traceSize_ = sizeof(Record) + e1.traceSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -576,18 +571,18 @@ public:
|
||||||
|
|
||||||
/// Return keys that play in this expression
|
/// Return keys that play in this expression
|
||||||
virtual std::set<Key> keys() const {
|
virtual std::set<Key> keys() const {
|
||||||
return expressionA1_->keys();
|
return this->template expression<A1, 1>()->keys();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const {
|
virtual T value(const Values& values) const {
|
||||||
return function_(this->expressionA1_->value(values), boost::none);
|
return function_(this->template expression<A1, 1>()->value(values), boost::none);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return value and derivatives
|
/// Return value and derivatives
|
||||||
virtual Augmented<T> forward(const Values& values) const {
|
virtual Augmented<T> forward(const Values& values) const {
|
||||||
using boost::none;
|
using boost::none;
|
||||||
Augmented<A1> argument = this->expressionA1_->forward(values);
|
Augmented<A1> argument = this->template expression<A1, 1>()->forward(values);
|
||||||
JacobianTA dTdA;
|
JacobianTA dTdA;
|
||||||
T t = function_(argument.value(),
|
T t = function_(argument.value(),
|
||||||
argument.constant() ? none : boost::optional<JacobianTA&>(dTdA));
|
argument.constant() ? none : boost::optional<JacobianTA&>(dTdA));
|
||||||
|
@ -605,7 +600,7 @@ public:
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
||||||
raw = (char*) (record + 1);
|
raw = (char*) (record + 1);
|
||||||
A1 a1 = expressionA1_->traceExecution(values,
|
A1 a1 = this-> template expression<A1, 1>()->traceExecution(values,
|
||||||
record->template trace<A1, 1>(), raw);
|
record->template trace<A1, 1>(), raw);
|
||||||
|
|
||||||
return function_(a1, record->template jacobian<A1, 1>());
|
return function_(a1, record->template jacobian<A1, 1>());
|
||||||
|
@ -616,7 +611,7 @@ public:
|
||||||
/// Binary Expression
|
/// Binary Expression
|
||||||
|
|
||||||
template<class T, class A1, class A2>
|
template<class T, class A1, class A2>
|
||||||
class BinaryExpression: public ExpressionNode<T> {
|
class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> > {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
@ -629,13 +624,13 @@ public:
|
||||||
private:
|
private:
|
||||||
|
|
||||||
Function function_;
|
Function function_;
|
||||||
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
|
|
||||||
boost::shared_ptr<ExpressionNode<A2> > expressionA2_;
|
|
||||||
|
|
||||||
/// Constructor with a binary function f, and two input arguments
|
/// Constructor with a ternary function f, and three input arguments
|
||||||
BinaryExpression(Function f, //
|
BinaryExpression(Function f, const Expression<A1>& e1,
|
||||||
const Expression<A1>& e1, const Expression<A2>& e2) :
|
const Expression<A2>& e2) :
|
||||||
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()) {
|
function_(f) {
|
||||||
|
this->template expression<A1, 1>() = e1.root();
|
||||||
|
this->template expression<A2, 2>() = e2.root();
|
||||||
ExpressionNode<T>::traceSize_ = //
|
ExpressionNode<T>::traceSize_ = //
|
||||||
sizeof(Record) + e1.traceSize() + e2.traceSize();
|
sizeof(Record) + e1.traceSize() + e2.traceSize();
|
||||||
}
|
}
|
||||||
|
@ -647,8 +642,8 @@ public:
|
||||||
|
|
||||||
/// Return keys that play in this expression
|
/// Return keys that play in this expression
|
||||||
virtual std::set<Key> keys() const {
|
virtual std::set<Key> keys() const {
|
||||||
std::set<Key> keys1 = expressionA1_->keys();
|
std::set<Key> keys1 = this->template expression<A1, 1>()->keys();
|
||||||
std::set<Key> keys2 = expressionA2_->keys();
|
std::set<Key> keys2 = this->template expression<A2, 2>()->keys();
|
||||||
keys1.insert(keys2.begin(), keys2.end());
|
keys1.insert(keys2.begin(), keys2.end());
|
||||||
return keys1;
|
return keys1;
|
||||||
}
|
}
|
||||||
|
@ -656,15 +651,16 @@ public:
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const {
|
virtual T value(const Values& values) const {
|
||||||
using boost::none;
|
using boost::none;
|
||||||
return function_(this->expressionA1_->value(values),
|
return function_(this->template expression<A1, 1>()->value(values),
|
||||||
this->expressionA2_->value(values), none, none);
|
this->template expression<A2, 2>()->value(values),
|
||||||
|
none, none);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return value and derivatives
|
/// Return value and derivatives
|
||||||
virtual Augmented<T> forward(const Values& values) const {
|
virtual Augmented<T> forward(const Values& values) const {
|
||||||
using boost::none;
|
using boost::none;
|
||||||
Augmented<A1> a1 = this->expressionA1_->forward(values);
|
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
|
||||||
Augmented<A2> a2 = this->expressionA2_->forward(values);
|
Augmented<A2> a2 = this->template expression<A2, 2>()->forward(values);
|
||||||
JacobianTA1 dTdA1;
|
JacobianTA1 dTdA1;
|
||||||
JacobianTA2 dTdA2;
|
JacobianTA2 dTdA2;
|
||||||
T t = function_(a1.value(), a2.value(),
|
T t = function_(a1.value(), a2.value(),
|
||||||
|
@ -678,30 +674,29 @@ public:
|
||||||
typedef Record<T, Arguments> Record;
|
typedef Record<T, Arguments> Record;
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
/// The raw buffer is [Record | A1 raw | A2 raw]
|
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
Record* record = new (raw) Record();
|
Record* record = new (raw) Record();
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
||||||
raw = (char*) (record + 1);
|
raw = (char*) (record + 1);
|
||||||
A1 a1 = expressionA1_->traceExecution(values,
|
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
||||||
record->template trace<A1, 1>(), raw);
|
record->template trace<A1, 1>(), raw);
|
||||||
raw = raw + expressionA1_->traceSize();
|
raw = raw + this->template expression<A1, 1>()->traceSize();
|
||||||
A2 a2 = expressionA2_->traceExecution(values,
|
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
||||||
record->template trace<A2, 2>(), raw);
|
record->template trace<A2, 2>(), raw);
|
||||||
|
raw = raw + this->template expression<A2, 2>()->traceSize();
|
||||||
|
|
||||||
return function_(a1, a2, record->template jacobian<A1, 1>(),
|
return function_(a1, a2, record->template jacobian<A1, 1>(),
|
||||||
record->template jacobian<A2, 2>());
|
record->template jacobian<A2, 2>());
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
/// Ternary Expression
|
/// Ternary Expression
|
||||||
|
|
||||||
template<class T, class A1, class A2, class A3>
|
template<class T, class A1, class A2, class A3>
|
||||||
class TernaryExpression: public ExpressionNode<T> {
|
class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> > {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
@ -715,15 +710,14 @@ public:
|
||||||
private:
|
private:
|
||||||
|
|
||||||
Function function_;
|
Function function_;
|
||||||
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
|
|
||||||
boost::shared_ptr<ExpressionNode<A2> > expressionA2_;
|
|
||||||
boost::shared_ptr<ExpressionNode<A3> > expressionA3_;
|
|
||||||
|
|
||||||
/// Constructor with a ternary function f, and three input arguments
|
/// Constructor with a ternary function f, and three input arguments
|
||||||
TernaryExpression(Function f, const Expression<A1>& e1,
|
TernaryExpression(Function f, const Expression<A1>& e1,
|
||||||
const Expression<A2>& e2, const Expression<A3>& e3) :
|
const Expression<A2>& e2, const Expression<A3>& e3) :
|
||||||
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_(
|
function_(f) {
|
||||||
e3.root()) {
|
this->template expression<A1, 1>() = e1.root();
|
||||||
|
this->template expression<A2, 2>() = e2.root();
|
||||||
|
this->template expression<A3, 3>() = e3.root();
|
||||||
ExpressionNode<T>::traceSize_ = //
|
ExpressionNode<T>::traceSize_ = //
|
||||||
sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize();
|
sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize();
|
||||||
}
|
}
|
||||||
|
@ -734,9 +728,9 @@ public:
|
||||||
|
|
||||||
/// Return keys that play in this expression
|
/// Return keys that play in this expression
|
||||||
virtual std::set<Key> keys() const {
|
virtual std::set<Key> keys() const {
|
||||||
std::set<Key> keys1 = expressionA1_->keys();
|
std::set<Key> keys1 = this->template expression<A1, 1>()->keys();
|
||||||
std::set<Key> keys2 = expressionA2_->keys();
|
std::set<Key> keys2 = this->template expression<A2, 2>()->keys();
|
||||||
std::set<Key> keys3 = expressionA3_->keys();
|
std::set<Key> keys3 = this->template expression<A3, 3>()->keys();
|
||||||
keys2.insert(keys3.begin(), keys3.end());
|
keys2.insert(keys3.begin(), keys3.end());
|
||||||
keys1.insert(keys2.begin(), keys2.end());
|
keys1.insert(keys2.begin(), keys2.end());
|
||||||
return keys1;
|
return keys1;
|
||||||
|
@ -745,17 +739,18 @@ public:
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const {
|
virtual T value(const Values& values) const {
|
||||||
using boost::none;
|
using boost::none;
|
||||||
return function_(this->expressionA1_->value(values),
|
return function_(this->template expression<A1, 1>()->value(values),
|
||||||
this->expressionA2_->value(values), this->expressionA3_->value(values),
|
this->template expression<A2, 2>()->value(values),
|
||||||
|
this->template expression<A3, 3>()->value(values),
|
||||||
none, none, none);
|
none, none, none);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return value and derivatives
|
/// Return value and derivatives
|
||||||
virtual Augmented<T> forward(const Values& values) const {
|
virtual Augmented<T> forward(const Values& values) const {
|
||||||
using boost::none;
|
using boost::none;
|
||||||
Augmented<A1> a1 = this->expressionA1_->forward(values);
|
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
|
||||||
Augmented<A2> a2 = this->expressionA2_->forward(values);
|
Augmented<A2> a2 = this->template expression<A2, 2>()->forward(values);
|
||||||
Augmented<A3> a3 = this->expressionA3_->forward(values);
|
Augmented<A3> a3 = this->template expression<A3, 3>()->forward(values);
|
||||||
JacobianTA1 dTdA1;
|
JacobianTA1 dTdA1;
|
||||||
JacobianTA2 dTdA2;
|
JacobianTA2 dTdA2;
|
||||||
JacobianTA3 dTdA3;
|
JacobianTA3 dTdA3;
|
||||||
|
@ -778,13 +773,13 @@ public:
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
||||||
raw = (char*) (record + 1);
|
raw = (char*) (record + 1);
|
||||||
A1 a1 = expressionA1_->traceExecution(values,
|
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
|
||||||
record->template trace<A1, 1>(), raw);
|
record->template trace<A1, 1>(), raw);
|
||||||
raw = raw + expressionA1_->traceSize();
|
raw = raw + this->template expression<A1, 1>()->traceSize();
|
||||||
A2 a2 = expressionA2_->traceExecution(values,
|
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
|
||||||
record->template trace<A2, 2>(), raw);
|
record->template trace<A2, 2>(), raw);
|
||||||
raw = raw + expressionA2_->traceSize();
|
raw = raw + this->template expression<A2, 2>()->traceSize();
|
||||||
A3 a3 = expressionA3_->traceExecution(values,
|
A3 a3 = this->template expression<A3, 3>()->traceExecution(values,
|
||||||
record->template trace<A3, 3>(), raw);
|
record->template trace<A3, 3>(), raw);
|
||||||
|
|
||||||
return function_(a1, a2, a3, record->template jacobian<A1, 1>(),
|
return function_(a1, a2, a3, record->template jacobian<A1, 1>(),
|
||||||
|
|
Loading…
Reference in New Issue