Pre-big collapse: prototype recursively defined inner Record2 type
							parent
							
								
									7fde47c48b
								
							
						
					
					
						commit
						bc9e11f43c
					
				|  | @ -307,10 +307,6 @@ struct Select<2, A> { | |||
| template<class T> | ||||
| class ExpressionNode { | ||||
| 
 | ||||
| public: | ||||
| 
 | ||||
|   static size_t const N = 0; // number of arguments
 | ||||
| 
 | ||||
| protected: | ||||
| 
 | ||||
|   size_t traceSize_; | ||||
|  | @ -510,7 +506,7 @@ struct Record: public boost::mpl::fold<TYPES, CallRecord<T::dimension>, | |||
| // to recursively generate a class, that will be the base for function nodes.
 | ||||
| // The class generated, for two arguments A1, A2, and A3 will be
 | ||||
| //
 | ||||
| // struct Base1 : Argument<T,A1,1>, ExpressionNode<T> {
 | ||||
| // struct Base1 : Argument<T,A1,1>, FunctionalBase<T> {
 | ||||
| //   ... storage related to A1 ...
 | ||||
| //   ... methods that work on A1 ...
 | ||||
| // };
 | ||||
|  | @ -535,7 +531,21 @@ struct Record: public boost::mpl::fold<TYPES, CallRecord<T::dimension>, | |||
| //-----------------------------------------------------------------------------
 | ||||
| 
 | ||||
| /**
 | ||||
|  * Building block for Recursive FunctionalNode Class | ||||
|  * Base case for recursive FunctionalNode class | ||||
|  */ | ||||
| template<class T> | ||||
| struct FunctionalBase: ExpressionNode<T> { | ||||
|   static size_t const N = 0; // number of arguments
 | ||||
| 
 | ||||
|   typedef CallRecord<T::dimension> Record2; | ||||
| 
 | ||||
|   /// Construct an execution trace for reverse AD
 | ||||
|   void trace(const Values& values, Record2* record, char*& raw) const { | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| /**
 | ||||
|  * Building block for recursive FunctionalNode class | ||||
|  * The integer argument N is to guarantee a unique type signature, | ||||
|  * so we are guaranteed to be able to extract their values by static cast. | ||||
|  */ | ||||
|  | @ -562,34 +572,91 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base { | |||
|     return keys; | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Recursive Record Class for Functional Expressions | ||||
|    */ | ||||
|   struct Record2: JacobianTrace<T, A, N>, Base::Record2 { | ||||
| 
 | ||||
|     typedef T return_type; | ||||
|     typedef JacobianTrace<T, A, N> This; | ||||
| 
 | ||||
|     /// Print to std::cout
 | ||||
|     virtual void print(const std::string& indent) const { | ||||
|       Base::Record2::print(indent); | ||||
|       static const Eigen::IOFormat matlab(0, 1, " ", "; ", "", "", "[", "]"); | ||||
|       std::cout << This::dTdA.format(matlab) << std::endl; | ||||
|       This::trace.print(indent); | ||||
|     } | ||||
| 
 | ||||
|     /// Start the reverse AD process
 | ||||
|     virtual void startReverseAD(JacobianMap& jacobians) const { | ||||
|       Base::Record2::startReverseAD(jacobians); | ||||
|       Select<T::dimension, A>::reverseAD(This::trace, This::dTdA, jacobians); | ||||
|     } | ||||
| 
 | ||||
|     /// Given df/dT, multiply in dT/dA and continue reverse AD process
 | ||||
|     virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { | ||||
|       Base::Record2::reverseAD(dFdT, jacobians); | ||||
|       This::trace.reverseAD(dFdT * This::dTdA, jacobians); | ||||
|     } | ||||
| 
 | ||||
|     /// Version specialized to 2-dimensional output
 | ||||
|     typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T; | ||||
|     virtual void reverseAD2(const Jacobian2T& dFdT, | ||||
|         JacobianMap& jacobians) const { | ||||
|       Base::Record2::reverseAD2(dFdT, jacobians); | ||||
|       This::trace.reverseAD2(dFdT * This::dTdA, jacobians); | ||||
|     } | ||||
|   }; | ||||
| 
 | ||||
|   /// Construct an execution trace for reverse AD
 | ||||
|   void trace(const Values& values, Record2* record, char*& raw) const { | ||||
|     Base::trace(values, record, raw); | ||||
|     A a = This::expression->traceExecution(values, record->Record2::This::trace, raw); | ||||
|     raw = raw + This::expression->traceSize(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| /**
 | ||||
|  *  Recursive GenerateFunctionalNode class Generator | ||||
|  */ | ||||
| template<class T, class TYPES> | ||||
| struct FunctionalNode: public boost::mpl::fold<TYPES, ExpressionNode<T>, | ||||
|     GenerateFunctionalNode<T, MPL::_2, MPL::_1> >::type { | ||||
| struct FunctionalNode { | ||||
|   typedef typename boost::mpl::fold<TYPES, FunctionalBase<T>, | ||||
|       GenerateFunctionalNode<T, MPL::_2, MPL::_1> >::type Base; | ||||
| 
 | ||||
|   /// Reset expression shared pointer
 | ||||
|   template<class A, size_t N> | ||||
|   void reset(const boost::shared_ptr<ExpressionNode<A> >& ptr) { | ||||
|     static_cast<Argument<T, A, N> &>(*this).expression = ptr; | ||||
|   } | ||||
|   struct type: public Base { | ||||
| 
 | ||||
|   /// Access Expression shared pointer
 | ||||
|   template<class A, size_t N> | ||||
|   boost::shared_ptr<ExpressionNode<A> > expression() const { | ||||
|     return static_cast<Argument<T, A, N> const &>(*this).expression; | ||||
|   } | ||||
|     /// Reset expression shared pointer
 | ||||
|     template<class A, size_t N> | ||||
|     void reset(const boost::shared_ptr<ExpressionNode<A> >& ptr) { | ||||
|       static_cast<Argument<T, A, N> &>(*this).expression = ptr; | ||||
|     } | ||||
| 
 | ||||
|     /// Access Expression shared pointer
 | ||||
|     template<class A, size_t N> | ||||
|     boost::shared_ptr<ExpressionNode<A> > expression() const { | ||||
|       return static_cast<Argument<T, A, N> const &>(*this).expression; | ||||
|     } | ||||
| 
 | ||||
|     /// Construct an execution trace for reverse AD
 | ||||
|     virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace, | ||||
|         char* raw) const { | ||||
|       typename Base::Record2* record = new (raw) typename Base::Record2(); | ||||
|       trace.setFunction(record); | ||||
|       raw = (char*) (record + 1); | ||||
| 
 | ||||
|       this->trace(values, record, raw); | ||||
| 
 | ||||
|       return T(); // TODO
 | ||||
|     } | ||||
|   }; | ||||
| }; | ||||
| 
 | ||||
| //-----------------------------------------------------------------------------
 | ||||
| 
 | ||||
| /// Unary Function Expression
 | ||||
| template<class T, class A1> | ||||
| class UnaryExpression: public FunctionalNode<T, boost::mpl::vector<A1> > { | ||||
| class UnaryExpression: public FunctionalNode<T, boost::mpl::vector<A1> >::type { | ||||
| 
 | ||||
|   /// The automatically generated Base class
 | ||||
|   typedef FunctionalNode<T, boost::mpl::vector<A1> > Base; | ||||
|  | @ -636,10 +703,11 @@ public: | |||
|       char* raw) const { | ||||
|     Record* record = new (raw) Record(); | ||||
|     trace.setFunction(record); | ||||
| 
 | ||||
|     raw = (char*) (record + 1); | ||||
|     A1 a1 = this-> template expression<A1, 1>()->traceExecution(values, | ||||
| 
 | ||||
|     A1 a1 = this->template expression<A1, 1>()->traceExecution(values, | ||||
|         record->template trace<A1, 1>(), raw); | ||||
|     raw = raw + this->template expression<A1, 1>()->traceSize(); | ||||
| 
 | ||||
|     return function_(a1, record->template jacobian<A1, 1>()); | ||||
|   } | ||||
|  | @ -649,7 +717,7 @@ public: | |||
| /// Binary Expression
 | ||||
| 
 | ||||
| template<class T, class A1, class A2> | ||||
| class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> > { | ||||
| class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> >::type { | ||||
| 
 | ||||
| public: | ||||
| 
 | ||||
|  | @ -706,13 +774,15 @@ public: | |||
|       char* raw) const { | ||||
|     Record* record = new (raw) Record(); | ||||
|     trace.setFunction(record); | ||||
| 
 | ||||
|     raw = (char*) (record + 1); | ||||
| 
 | ||||
|     A1 a1 = this->template expression<A1, 1>()->traceExecution(values, | ||||
|         record->template trace<A1, 1>(), raw); | ||||
|     raw = raw + this->template expression<A1, 1>()->traceSize(); | ||||
| 
 | ||||
|     A2 a2 = this->template expression<A2, 2>()->traceExecution(values, | ||||
|         record->template trace<A2, 2>(), raw); | ||||
|     raw = raw + this->template expression<A2, 2>()->traceSize(); | ||||
| 
 | ||||
|     return function_(a1, a2, record->template jacobian<A1, 1>(), | ||||
|         record->template jacobian<A2, 2>()); | ||||
|  | @ -723,7 +793,7 @@ public: | |||
| /// Ternary Expression
 | ||||
| 
 | ||||
| template<class T, class A1, class A2, class A3> | ||||
| class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> > { | ||||
| class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type { | ||||
| 
 | ||||
| public: | ||||
| 
 | ||||
|  | @ -786,16 +856,19 @@ public: | |||
|       char* raw) const { | ||||
|     Record* record = new (raw) Record(); | ||||
|     trace.setFunction(record); | ||||
| 
 | ||||
|     raw = (char*) (record + 1); | ||||
| 
 | ||||
|     A1 a1 = this->template expression<A1, 1>()->traceExecution(values, | ||||
|         record->template trace<A1, 1>(), raw); | ||||
|     raw = raw + this->template expression<A1, 1>()->traceSize(); | ||||
| 
 | ||||
|     A2 a2 = this->template expression<A2, 2>()->traceExecution(values, | ||||
|         record->template trace<A2, 2>(), raw); | ||||
|     raw = raw + this->template expression<A2, 2>()->traceSize(); | ||||
| 
 | ||||
|     A3 a3 = this->template expression<A3, 3>()->traceExecution(values, | ||||
|         record->template trace<A3, 3>(), raw); | ||||
|     raw = raw + this->template expression<A3, 3>()->traceSize(); | ||||
| 
 | ||||
|     return function_(a1, a2, a3, record->template jacobian<A1, 1>(), | ||||
|         record->template jacobian<A2, 2>(), record->template jacobian<A3, 3>()); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue