FunctionalNode inherited for all three functional ExpressionNode sub-classes

release/4.3a0
dellaert 2014-10-13 00:31:03 +02:00
parent 8100d89094
commit a9d9fcd241
1 changed files with 58 additions and 63 deletions

View File

@ -489,22 +489,16 @@ template<class T, class TYPES>
struct Record: public boost::mpl::fold<TYPES, CallRecord<T::dimension>,
GenerateRecord<T, MPL::_2, MPL::_1> >::type {
/// Access JacobianTrace
template<class A, size_t N>
JacobianTrace<typename Record::return_type, A, N>& jacobianTrace() {
return static_cast<JacobianTrace<T, A, N>&>(*this);
}
/// Access Trace
template<class A, size_t N>
ExecutionTrace<A>& trace() {
return jacobianTrace<A, N>().trace;
return static_cast<JacobianTrace<T, A, N>&>(*this).trace;
}
/// Access Jacobian
template<class A, size_t N>
Eigen::Matrix<double, T::dimension, A::dimension>& jacobian() {
return jacobianTrace<A, N>().dTdA;
return static_cast<JacobianTrace<T, A, N>&>(*this).dTdA;
}
};
@ -534,20 +528,21 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
template<class T, class TYPES>
struct FunctionalNode: public boost::mpl::fold<TYPES, ExpressionNode<T>,
GenerateFunctionalNode<T, MPL::_2, MPL::_1> >::type {
/// Access Expression
template<class A, size_t N>
boost::shared_ptr<ExpressionNode<A> > expression() {
return static_cast<Argument<T, A, N> &>(*this).expression;
}
/// Access Expression, const version
template<class A, size_t N>
boost::shared_ptr<ExpressionNode<A> > expression() const {
return static_cast<Argument<T, A, N> const &>(*this).expression;
}
};
/// Access Argument
template<class A, size_t N, class Record>
Argument<typename Record::return_type, A, N>& argument(Record& record) {
return static_cast<Argument<typename Record::return_type, A, N>&>(record);
}
/// Access Expression
template<class A, size_t N, class Record>
ExecutionTrace<A>& expression(Record* record) {
return argument<A, N>(*record).expression;
}
//-----------------------------------------------------------------------------
/// Unary Function Expression
@ -562,11 +557,11 @@ public:
private:
Function function_;
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
/// Constructor with a unary function f, and input argument e
UnaryExpression(Function f, const Expression<A1>& e1) :
function_(f), expressionA1_(e1.root()) {
function_(f) {
this->template expression<A1, 1>() = e1.root();
ExpressionNode<T>::traceSize_ = sizeof(Record) + e1.traceSize();
}
@ -576,18 +571,18 @@ public:
/// Return keys that play in this expression
virtual std::set<Key> keys() const {
return expressionA1_->keys();
return this->template expression<A1, 1>()->keys();
}
/// Return value
virtual T value(const Values& values) const {
return function_(this->expressionA1_->value(values), boost::none);
return function_(this->template expression<A1, 1>()->value(values), boost::none);
}
/// Return value and derivatives
virtual Augmented<T> forward(const Values& values) const {
using boost::none;
Augmented<A1> argument = this->expressionA1_->forward(values);
Augmented<A1> argument = this->template expression<A1, 1>()->forward(values);
JacobianTA dTdA;
T t = function_(argument.value(),
argument.constant() ? none : boost::optional<JacobianTA&>(dTdA));
@ -605,7 +600,7 @@ public:
trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = expressionA1_->traceExecution(values,
A1 a1 = this-> template expression<A1, 1>()->traceExecution(values,
record->template trace<A1, 1>(), raw);
return function_(a1, record->template jacobian<A1, 1>());
@ -616,7 +611,7 @@ public:
/// Binary Expression
template<class T, class A1, class A2>
class BinaryExpression: public ExpressionNode<T> {
class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> > {
public:
@ -629,13 +624,13 @@ public:
private:
Function function_;
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
boost::shared_ptr<ExpressionNode<A2> > expressionA2_;
/// Constructor with a binary function f, and two input arguments
BinaryExpression(Function f, //
const Expression<A1>& e1, const Expression<A2>& e2) :
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()) {
/// Constructor with a ternary function f, and three input arguments
BinaryExpression(Function f, const Expression<A1>& e1,
const Expression<A2>& e2) :
function_(f) {
this->template expression<A1, 1>() = e1.root();
this->template expression<A2, 2>() = e2.root();
ExpressionNode<T>::traceSize_ = //
sizeof(Record) + e1.traceSize() + e2.traceSize();
}
@ -647,8 +642,8 @@ public:
/// Return keys that play in this expression
virtual std::set<Key> keys() const {
std::set<Key> keys1 = expressionA1_->keys();
std::set<Key> keys2 = expressionA2_->keys();
std::set<Key> keys1 = this->template expression<A1, 1>()->keys();
std::set<Key> keys2 = this->template expression<A2, 2>()->keys();
keys1.insert(keys2.begin(), keys2.end());
return keys1;
}
@ -656,15 +651,16 @@ public:
/// Return value
virtual T value(const Values& values) const {
using boost::none;
return function_(this->expressionA1_->value(values),
this->expressionA2_->value(values), none, none);
return function_(this->template expression<A1, 1>()->value(values),
this->template expression<A2, 2>()->value(values),
none, none);
}
/// Return value and derivatives
virtual Augmented<T> forward(const Values& values) const {
using boost::none;
Augmented<A1> a1 = this->expressionA1_->forward(values);
Augmented<A2> a2 = this->expressionA2_->forward(values);
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
Augmented<A2> a2 = this->template expression<A2, 2>()->forward(values);
JacobianTA1 dTdA1;
JacobianTA2 dTdA2;
T t = function_(a1.value(), a2.value(),
@ -678,30 +674,29 @@ public:
typedef Record<T, Arguments> Record;
/// Construct an execution trace for reverse AD
/// The raw buffer is [Record | A1 raw | A2 raw]
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
char* raw) const {
Record* record = new (raw) Record();
trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = expressionA1_->traceExecution(values,
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
record->template trace<A1, 1>(), raw);
raw = raw + expressionA1_->traceSize();
A2 a2 = expressionA2_->traceExecution(values,
raw = raw + this->template expression<A1, 1>()->traceSize();
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
record->template trace<A2, 2>(), raw);
raw = raw + this->template expression<A2, 2>()->traceSize();
return function_(a1, a2, record->template jacobian<A1, 1>(),
record->template jacobian<A2, 2>());
}
};
//-----------------------------------------------------------------------------
/// Ternary Expression
template<class T, class A1, class A2, class A3>
class TernaryExpression: public ExpressionNode<T> {
class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> > {
public:
@ -715,15 +710,14 @@ public:
private:
Function function_;
boost::shared_ptr<ExpressionNode<A1> > expressionA1_;
boost::shared_ptr<ExpressionNode<A2> > expressionA2_;
boost::shared_ptr<ExpressionNode<A3> > expressionA3_;
/// Constructor with a ternary function f, and three input arguments
TernaryExpression(Function f, const Expression<A1>& e1,
const Expression<A2>& e2, const Expression<A3>& e3) :
function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_(
e3.root()) {
function_(f) {
this->template expression<A1, 1>() = e1.root();
this->template expression<A2, 2>() = e2.root();
this->template expression<A3, 3>() = e3.root();
ExpressionNode<T>::traceSize_ = //
sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize();
}
@ -734,9 +728,9 @@ public:
/// Return keys that play in this expression
virtual std::set<Key> keys() const {
std::set<Key> keys1 = expressionA1_->keys();
std::set<Key> keys2 = expressionA2_->keys();
std::set<Key> keys3 = expressionA3_->keys();
std::set<Key> keys1 = this->template expression<A1, 1>()->keys();
std::set<Key> keys2 = this->template expression<A2, 2>()->keys();
std::set<Key> keys3 = this->template expression<A3, 3>()->keys();
keys2.insert(keys3.begin(), keys3.end());
keys1.insert(keys2.begin(), keys2.end());
return keys1;
@ -745,17 +739,18 @@ public:
/// Return value
virtual T value(const Values& values) const {
using boost::none;
return function_(this->expressionA1_->value(values),
this->expressionA2_->value(values), this->expressionA3_->value(values),
return function_(this->template expression<A1, 1>()->value(values),
this->template expression<A2, 2>()->value(values),
this->template expression<A3, 3>()->value(values),
none, none, none);
}
/// Return value and derivatives
virtual Augmented<T> forward(const Values& values) const {
using boost::none;
Augmented<A1> a1 = this->expressionA1_->forward(values);
Augmented<A2> a2 = this->expressionA2_->forward(values);
Augmented<A3> a3 = this->expressionA3_->forward(values);
Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values);
Augmented<A2> a2 = this->template expression<A2, 2>()->forward(values);
Augmented<A3> a3 = this->template expression<A3, 3>()->forward(values);
JacobianTA1 dTdA1;
JacobianTA2 dTdA2;
JacobianTA3 dTdA3;
@ -778,13 +773,13 @@ public:
trace.setFunction(record);
raw = (char*) (record + 1);
A1 a1 = expressionA1_->traceExecution(values,
A1 a1 = this->template expression<A1, 1>()->traceExecution(values,
record->template trace<A1, 1>(), raw);
raw = raw + expressionA1_->traceSize();
A2 a2 = expressionA2_->traceExecution(values,
raw = raw + this->template expression<A1, 1>()->traceSize();
A2 a2 = this->template expression<A2, 2>()->traceExecution(values,
record->template trace<A2, 2>(), raw);
raw = raw + expressionA2_->traceSize();
A3 a3 = expressionA3_->traceExecution(values,
raw = raw + this->template expression<A2, 2>()->traceSize();
A3 a3 = this->template expression<A3, 3>()->traceExecution(values,
record->template trace<A3, 3>(), raw);
return function_(a1, a2, a3, record->template jacobian<A1, 1>(),