FunctionalNode inherited for all three functional ExpressionNode sub-classes
							parent
							
								
									8100d89094
								
							
						
					
					
						commit
						a9d9fcd241
					
				|  | @ -489,22 +489,16 @@ template<class T, class TYPES> | |||
| struct Record: public boost::mpl::fold<TYPES, CallRecord<T::dimension>, | ||||
|     GenerateRecord<T, MPL::_2, MPL::_1> >::type { | ||||
| 
 | ||||
|   /// Access JacobianTrace
 | ||||
|   template<class A, size_t N> | ||||
|   JacobianTrace<typename Record::return_type, A, N>& jacobianTrace() { | ||||
|     return static_cast<JacobianTrace<T, A, N>&>(*this); | ||||
|   } | ||||
| 
 | ||||
|   /// Access Trace
 | ||||
|   template<class A, size_t N> | ||||
|   ExecutionTrace<A>& trace() { | ||||
|     return jacobianTrace<A, N>().trace; | ||||
|     return static_cast<JacobianTrace<T, A, N>&>(*this).trace; | ||||
|   } | ||||
| 
 | ||||
|   /// Access Jacobian
 | ||||
|   template<class A, size_t N> | ||||
|   Eigen::Matrix<double, T::dimension, A::dimension>& jacobian() { | ||||
|     return jacobianTrace<A, N>().dTdA; | ||||
|     return static_cast<JacobianTrace<T, A, N>&>(*this).dTdA; | ||||
|   } | ||||
| 
 | ||||
| }; | ||||
|  | @ -534,20 +528,21 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base { | |||
| template<class T, class TYPES> | ||||
| struct FunctionalNode: public boost::mpl::fold<TYPES, ExpressionNode<T>, | ||||
|     GenerateFunctionalNode<T, MPL::_2, MPL::_1> >::type { | ||||
| 
 | ||||
|   /// Access Expression
 | ||||
|   template<class A, size_t N> | ||||
|   boost::shared_ptr<ExpressionNode<A> > expression() { | ||||
|     return static_cast<Argument<T, A, N> &>(*this).expression; | ||||
|   } | ||||
| 
 | ||||
|   /// Access Expression, const version
 | ||||
|   template<class A, size_t N> | ||||
|   boost::shared_ptr<ExpressionNode<A> > expression() const { | ||||
|     return static_cast<Argument<T, A, N> const &>(*this).expression; | ||||
|   } | ||||
| 
 | ||||
| }; | ||||
| 
 | ||||
| /// Access Argument
 | ||||
| template<class A, size_t N, class Record> | ||||
| Argument<typename Record::return_type, A, N>& argument(Record& record) { | ||||
|   return static_cast<Argument<typename Record::return_type, A, N>&>(record); | ||||
| } | ||||
| 
 | ||||
| /// Access Expression
 | ||||
| template<class A, size_t N, class Record> | ||||
| ExecutionTrace<A>& expression(Record* record) { | ||||
|   return argument<A, N>(*record).expression; | ||||
| } | ||||
| 
 | ||||
| //-----------------------------------------------------------------------------
 | ||||
| 
 | ||||
| /// Unary Function Expression
 | ||||
|  | @ -562,11 +557,11 @@ public: | |||
| private: | ||||
| 
 | ||||
|   Function function_; | ||||
|   boost::shared_ptr<ExpressionNode<A1> > expressionA1_; | ||||
| 
 | ||||
|   /// Constructor with a unary function f, and input argument e
 | ||||
|   UnaryExpression(Function f, const Expression<A1>& e1) : | ||||
|       function_(f), expressionA1_(e1.root()) { | ||||
|       function_(f) { | ||||
|     this->template expression<A1, 1>() = e1.root(); | ||||
|     ExpressionNode<T>::traceSize_ = sizeof(Record) + e1.traceSize(); | ||||
|   } | ||||
| 
 | ||||
|  | @ -576,18 +571,18 @@ public: | |||
| 
 | ||||
|   /// Return keys that play in this expression
 | ||||
|   virtual std::set<Key> keys() const { | ||||
|     return expressionA1_->keys(); | ||||
|     return this->template expression<A1, 1>()->keys(); | ||||
|   } | ||||
| 
 | ||||
|   /// Return value
 | ||||
|   virtual T value(const Values& values) const { | ||||
|     return function_(this->expressionA1_->value(values), boost::none); | ||||
|     return function_(this->template expression<A1, 1>()->value(values), boost::none); | ||||
|   } | ||||
| 
 | ||||
|   /// Return value and derivatives
 | ||||
|   virtual Augmented<T> forward(const Values& values) const { | ||||
|     using boost::none; | ||||
|     Augmented<A1> argument = this->expressionA1_->forward(values); | ||||
|     Augmented<A1> argument = this->template expression<A1, 1>()->forward(values); | ||||
|     JacobianTA dTdA; | ||||
|     T t = function_(argument.value(), | ||||
|         argument.constant() ? none : boost::optional<JacobianTA&>(dTdA)); | ||||
|  | @ -605,7 +600,7 @@ public: | |||
|     trace.setFunction(record); | ||||
| 
 | ||||
|     raw = (char*) (record + 1); | ||||
|     A1 a1 = expressionA1_->traceExecution(values, | ||||
|     A1 a1 = this-> template expression<A1, 1>()->traceExecution(values, | ||||
|         record->template trace<A1, 1>(), raw); | ||||
| 
 | ||||
|     return function_(a1, record->template jacobian<A1, 1>()); | ||||
|  | @ -616,7 +611,7 @@ public: | |||
| /// Binary Expression
 | ||||
| 
 | ||||
| template<class T, class A1, class A2> | ||||
| class BinaryExpression: public ExpressionNode<T> { | ||||
| class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> > { | ||||
| 
 | ||||
| public: | ||||
| 
 | ||||
|  | @ -629,13 +624,13 @@ public: | |||
| private: | ||||
| 
 | ||||
|   Function function_; | ||||
|   boost::shared_ptr<ExpressionNode<A1> > expressionA1_; | ||||
|   boost::shared_ptr<ExpressionNode<A2> > expressionA2_; | ||||
| 
 | ||||
|   /// Constructor with a binary function f, and two input arguments
 | ||||
|   BinaryExpression(Function f, //
 | ||||
|       const Expression<A1>& e1, const Expression<A2>& e2) : | ||||
|       function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()) { | ||||
|   /// Constructor with a ternary function f, and three input arguments
 | ||||
|   BinaryExpression(Function f, const Expression<A1>& e1, | ||||
|       const Expression<A2>& e2) : | ||||
|       function_(f) { | ||||
|     this->template expression<A1, 1>() = e1.root(); | ||||
|     this->template expression<A2, 2>() = e2.root(); | ||||
|     ExpressionNode<T>::traceSize_ = //
 | ||||
|         sizeof(Record) + e1.traceSize() + e2.traceSize(); | ||||
|   } | ||||
|  | @ -647,8 +642,8 @@ public: | |||
| 
 | ||||
|   /// Return keys that play in this expression
 | ||||
|   virtual std::set<Key> keys() const { | ||||
|     std::set<Key> keys1 = expressionA1_->keys(); | ||||
|     std::set<Key> keys2 = expressionA2_->keys(); | ||||
|     std::set<Key> keys1 = this->template expression<A1, 1>()->keys(); | ||||
|     std::set<Key> keys2 = this->template expression<A2, 2>()->keys(); | ||||
|     keys1.insert(keys2.begin(), keys2.end()); | ||||
|     return keys1; | ||||
|   } | ||||
|  | @ -656,15 +651,16 @@ public: | |||
|   /// Return value
 | ||||
|   virtual T value(const Values& values) const { | ||||
|     using boost::none; | ||||
|     return function_(this->expressionA1_->value(values), | ||||
|         this->expressionA2_->value(values), none, none); | ||||
|     return function_(this->template expression<A1, 1>()->value(values), | ||||
|         this->template expression<A2, 2>()->value(values), | ||||
|         none, none); | ||||
|   } | ||||
| 
 | ||||
|   /// Return value and derivatives
 | ||||
|   virtual Augmented<T> forward(const Values& values) const { | ||||
|     using boost::none; | ||||
|     Augmented<A1> a1 = this->expressionA1_->forward(values); | ||||
|     Augmented<A2> a2 = this->expressionA2_->forward(values); | ||||
|     Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values); | ||||
|     Augmented<A2> a2 = this->template expression<A2, 2>()->forward(values); | ||||
|     JacobianTA1 dTdA1; | ||||
|     JacobianTA2 dTdA2; | ||||
|     T t = function_(a1.value(), a2.value(), | ||||
|  | @ -678,30 +674,29 @@ public: | |||
|   typedef Record<T, Arguments> Record; | ||||
| 
 | ||||
|   /// Construct an execution trace for reverse AD
 | ||||
|   /// The raw buffer is [Record | A1 raw | A2 raw]
 | ||||
|   virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, | ||||
|       char* raw) const { | ||||
|     Record* record = new (raw) Record(); | ||||
|     trace.setFunction(record); | ||||
| 
 | ||||
|     raw = (char*) (record + 1); | ||||
|     A1 a1 = expressionA1_->traceExecution(values, | ||||
|     A1 a1 = this->template expression<A1, 1>()->traceExecution(values, | ||||
|         record->template trace<A1, 1>(), raw); | ||||
|     raw = raw + expressionA1_->traceSize(); | ||||
|     A2 a2 = expressionA2_->traceExecution(values, | ||||
|     raw = raw + this->template expression<A1, 1>()->traceSize(); | ||||
|     A2 a2 = this->template expression<A2, 2>()->traceExecution(values, | ||||
|         record->template trace<A2, 2>(), raw); | ||||
|     raw = raw + this->template expression<A2, 2>()->traceSize(); | ||||
| 
 | ||||
|     return function_(a1, a2, record->template jacobian<A1, 1>(), | ||||
|         record->template jacobian<A2, 2>()); | ||||
|   } | ||||
| 
 | ||||
| }; | ||||
| 
 | ||||
| //-----------------------------------------------------------------------------
 | ||||
| /// Ternary Expression
 | ||||
| 
 | ||||
| template<class T, class A1, class A2, class A3> | ||||
| class TernaryExpression: public ExpressionNode<T> { | ||||
| class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> > { | ||||
| 
 | ||||
| public: | ||||
| 
 | ||||
|  | @ -715,15 +710,14 @@ public: | |||
| private: | ||||
| 
 | ||||
|   Function function_; | ||||
|   boost::shared_ptr<ExpressionNode<A1> > expressionA1_; | ||||
|   boost::shared_ptr<ExpressionNode<A2> > expressionA2_; | ||||
|   boost::shared_ptr<ExpressionNode<A3> > expressionA3_; | ||||
| 
 | ||||
|   /// Constructor with a ternary function f, and three input arguments
 | ||||
|   TernaryExpression(Function f, const Expression<A1>& e1, | ||||
|       const Expression<A2>& e2, const Expression<A3>& e3) : | ||||
|       function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_( | ||||
|           e3.root()) { | ||||
|       function_(f) { | ||||
|     this->template expression<A1, 1>() = e1.root(); | ||||
|     this->template expression<A2, 2>() = e2.root(); | ||||
|     this->template expression<A3, 3>() = e3.root(); | ||||
|     ExpressionNode<T>::traceSize_ = //
 | ||||
|         sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize(); | ||||
|   } | ||||
|  | @ -734,9 +728,9 @@ public: | |||
| 
 | ||||
|   /// Return keys that play in this expression
 | ||||
|   virtual std::set<Key> keys() const { | ||||
|     std::set<Key> keys1 = expressionA1_->keys(); | ||||
|     std::set<Key> keys2 = expressionA2_->keys(); | ||||
|     std::set<Key> keys3 = expressionA3_->keys(); | ||||
|     std::set<Key> keys1 = this->template expression<A1, 1>()->keys(); | ||||
|     std::set<Key> keys2 = this->template expression<A2, 2>()->keys(); | ||||
|     std::set<Key> keys3 = this->template expression<A3, 3>()->keys(); | ||||
|     keys2.insert(keys3.begin(), keys3.end()); | ||||
|     keys1.insert(keys2.begin(), keys2.end()); | ||||
|     return keys1; | ||||
|  | @ -745,17 +739,18 @@ public: | |||
|   /// Return value
 | ||||
|   virtual T value(const Values& values) const { | ||||
|     using boost::none; | ||||
|     return function_(this->expressionA1_->value(values), | ||||
|         this->expressionA2_->value(values), this->expressionA3_->value(values), | ||||
|     return function_(this->template expression<A1, 1>()->value(values), | ||||
|         this->template expression<A2, 2>()->value(values), | ||||
|         this->template expression<A3, 3>()->value(values), | ||||
|         none, none, none); | ||||
|   } | ||||
| 
 | ||||
|   /// Return value and derivatives
 | ||||
|   virtual Augmented<T> forward(const Values& values) const { | ||||
|     using boost::none; | ||||
|     Augmented<A1> a1 = this->expressionA1_->forward(values); | ||||
|     Augmented<A2> a2 = this->expressionA2_->forward(values); | ||||
|     Augmented<A3> a3 = this->expressionA3_->forward(values); | ||||
|     Augmented<A1> a1 = this->template expression<A1, 1>()->forward(values); | ||||
|     Augmented<A2> a2 = this->template expression<A2, 2>()->forward(values); | ||||
|     Augmented<A3> a3 = this->template expression<A3, 3>()->forward(values); | ||||
|     JacobianTA1 dTdA1; | ||||
|     JacobianTA2 dTdA2; | ||||
|     JacobianTA3 dTdA3; | ||||
|  | @ -778,13 +773,13 @@ public: | |||
|     trace.setFunction(record); | ||||
| 
 | ||||
|     raw = (char*) (record + 1); | ||||
|     A1 a1 = expressionA1_->traceExecution(values, | ||||
|     A1 a1 = this->template expression<A1, 1>()->traceExecution(values, | ||||
|         record->template trace<A1, 1>(), raw); | ||||
|     raw = raw + expressionA1_->traceSize(); | ||||
|     A2 a2 = expressionA2_->traceExecution(values, | ||||
|     raw = raw + this->template expression<A1, 1>()->traceSize(); | ||||
|     A2 a2 = this->template expression<A2, 2>()->traceExecution(values, | ||||
|         record->template trace<A2, 2>(), raw); | ||||
|     raw = raw + expressionA2_->traceSize(); | ||||
|     A3 a3 = expressionA3_->traceExecution(values, | ||||
|     raw = raw + this->template expression<A2, 2>()->traceSize(); | ||||
|     A3 a3 = this->template expression<A3, 3>()->traceExecution(values, | ||||
|         record->template trace<A3, 3>(), raw); | ||||
| 
 | ||||
|     return function_(a1, a2, a3, record->template jacobian<A1, 1>(), | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue