diff --git a/gtsam_unstable/nonlinear/Expression-inl.h b/gtsam_unstable/nonlinear/Expression-inl.h index 5ee1ca272..08a0e0bc6 100644 --- a/gtsam_unstable/nonlinear/Expression-inl.h +++ b/gtsam_unstable/nonlinear/Expression-inl.h @@ -392,7 +392,11 @@ class ExpressionNode { protected: - ExpressionNode() { + size_t traceSize_; + + /// Constructor, traceSize is size of the execution trace of expression rooted here + ExpressionNode(size_t traceSize = 0) : + traceSize_(traceSize) { } public: @@ -404,17 +408,17 @@ public: /// Return keys that play in this expression as a set virtual std::set keys() const = 0; + // Return size needed for memory buffer in traceExecution + size_t traceSize() const { + return traceSize_; + } + /// Return value virtual T value(const Values& values) const = 0; /// Return value and derivatives virtual Augmented forward(const Values& values) const = 0; - // Return size needed for memory buffer in traceExecution - virtual size_t traceSize() const { - return 0; - } - /// Construct an execution trace for reverse AD virtual T traceExecution(const Values& values, ExecutionTrace& trace, char* raw) const = 0; @@ -519,8 +523,9 @@ private: boost::shared_ptr > expressionA1_; /// Constructor with a unary function f, and input argument e - UnaryExpression(Function f, const Expression& e) : - function_(f), expressionA1_(e.root()) { + UnaryExpression(Function f, const Expression& e1) : + ExpressionNode(sizeof(Record) + e1.traceSize()), // + function_(f), expressionA1_(e1.root()) { } friend class Expression ; @@ -551,11 +556,6 @@ public: typedef boost::mpl::vector > Arguments; typedef typename GenerateRecord::type Record; - // Return size needed for memory buffer in traceExecution - virtual size_t traceSize() const { - return sizeof(Record) + expressionA1_->traceSize(); - } - /// Construct an execution trace for reverse AD virtual T traceExecution(const Values& values, ExecutionTrace& trace, char* raw) const { @@ -592,7 +592,8 @@ private: /// Constructor with a binary function f, and two input arguments BinaryExpression(Function f, // const Expression& e1, const Expression& e2) : - function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()) { + ExpressionNode(sizeof(Record) + e1.traceSize() + e2.traceSize()), function_( + f), expressionA1_(e1.root()), expressionA2_(e2.root()) { } friend class Expression ; @@ -632,12 +633,6 @@ public: typedef boost::mpl::vector, Numbered > Arguments; typedef typename GenerateRecord::type Record; - // Return size needed for memory buffer in traceExecution - virtual size_t traceSize() const { - return sizeof(Record) + expressionA1_->traceSize() - + expressionA2_->traceSize(); - } - /// Construct an execution trace for reverse AD /// The raw buffer is [Record | A1 raw | A2 raw] virtual T traceExecution(const Values& values, ExecutionTrace& trace, @@ -682,7 +677,9 @@ private: Function f, // const Expression& e1, const Expression& e2, const Expression& e3) : - function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_( + ExpressionNode( + sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize()), function_( + f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_( e3.root()) { } @@ -729,12 +726,6 @@ public: typedef boost::mpl::vector, Numbered, Numbered > Arguments; typedef typename GenerateRecord::type Record; - // Return size needed for memory buffer in traceExecution - virtual size_t traceSize() const { - return sizeof(Record) + expressionA1_->traceSize() - + expressionA2_->traceSize() + expressionA2_->traceSize(); - } - /// Construct an execution trace for reverse AD virtual T traceExecution(const Values& values, ExecutionTrace& trace, char* raw) const {