From a9d9fcd241ae117179bc09ed7baa698a09cf22a7 Mon Sep 17 00:00:00 2001 From: dellaert Date: Mon, 13 Oct 2014 00:31:03 +0200 Subject: [PATCH] FunctionalNode inherited for all three functional ExpressionNode sub-classes --- gtsam_unstable/nonlinear/Expression-inl.h | 121 +++++++++++----------- 1 file changed, 58 insertions(+), 63 deletions(-) diff --git a/gtsam_unstable/nonlinear/Expression-inl.h b/gtsam_unstable/nonlinear/Expression-inl.h index 6fef90d38..a765177aa 100644 --- a/gtsam_unstable/nonlinear/Expression-inl.h +++ b/gtsam_unstable/nonlinear/Expression-inl.h @@ -489,22 +489,16 @@ template struct Record: public boost::mpl::fold, GenerateRecord >::type { - /// Access JacobianTrace - template - JacobianTrace& jacobianTrace() { - return static_cast&>(*this); - } - /// Access Trace template ExecutionTrace& trace() { - return jacobianTrace().trace; + return static_cast&>(*this).trace; } /// Access Jacobian template Eigen::Matrix& jacobian() { - return jacobianTrace().dTdA; + return static_cast&>(*this).dTdA; } }; @@ -534,20 +528,21 @@ struct GenerateFunctionalNode: Argument, Base { template struct FunctionalNode: public boost::mpl::fold, GenerateFunctionalNode >::type { + + /// Access Expression + template + boost::shared_ptr > expression() { + return static_cast &>(*this).expression; + } + + /// Access Expression, const version + template + boost::shared_ptr > expression() const { + return static_cast const &>(*this).expression; + } + }; -/// Access Argument -template -Argument& argument(Record& record) { - return static_cast&>(record); -} - -/// Access Expression -template -ExecutionTrace& expression(Record* record) { - return argument(*record).expression; -} - //----------------------------------------------------------------------------- /// Unary Function Expression @@ -562,11 +557,11 @@ public: private: Function function_; - boost::shared_ptr > expressionA1_; /// Constructor with a unary function f, and input argument e UnaryExpression(Function f, const Expression& e1) : - function_(f), expressionA1_(e1.root()) { + function_(f) { + this->template expression() = e1.root(); ExpressionNode::traceSize_ = sizeof(Record) + e1.traceSize(); } @@ -576,18 +571,18 @@ public: /// Return keys that play in this expression virtual std::set keys() const { - return expressionA1_->keys(); + return this->template expression()->keys(); } /// Return value virtual T value(const Values& values) const { - return function_(this->expressionA1_->value(values), boost::none); + return function_(this->template expression()->value(values), boost::none); } /// Return value and derivatives virtual Augmented forward(const Values& values) const { using boost::none; - Augmented argument = this->expressionA1_->forward(values); + Augmented argument = this->template expression()->forward(values); JacobianTA dTdA; T t = function_(argument.value(), argument.constant() ? none : boost::optional(dTdA)); @@ -605,7 +600,7 @@ public: trace.setFunction(record); raw = (char*) (record + 1); - A1 a1 = expressionA1_->traceExecution(values, + A1 a1 = this-> template expression()->traceExecution(values, record->template trace(), raw); return function_(a1, record->template jacobian()); @@ -616,7 +611,7 @@ public: /// Binary Expression template -class BinaryExpression: public ExpressionNode { +class BinaryExpression: public FunctionalNode > { public: @@ -629,13 +624,13 @@ public: private: Function function_; - boost::shared_ptr > expressionA1_; - boost::shared_ptr > expressionA2_; - /// Constructor with a binary function f, and two input arguments - BinaryExpression(Function f, // - const Expression& e1, const Expression& e2) : - function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()) { + /// Constructor with a ternary function f, and three input arguments + BinaryExpression(Function f, const Expression& e1, + const Expression& e2) : + function_(f) { + this->template expression() = e1.root(); + this->template expression() = e2.root(); ExpressionNode::traceSize_ = // sizeof(Record) + e1.traceSize() + e2.traceSize(); } @@ -647,8 +642,8 @@ public: /// Return keys that play in this expression virtual std::set keys() const { - std::set keys1 = expressionA1_->keys(); - std::set keys2 = expressionA2_->keys(); + std::set keys1 = this->template expression()->keys(); + std::set keys2 = this->template expression()->keys(); keys1.insert(keys2.begin(), keys2.end()); return keys1; } @@ -656,15 +651,16 @@ public: /// Return value virtual T value(const Values& values) const { using boost::none; - return function_(this->expressionA1_->value(values), - this->expressionA2_->value(values), none, none); + return function_(this->template expression()->value(values), + this->template expression()->value(values), + none, none); } /// Return value and derivatives virtual Augmented forward(const Values& values) const { using boost::none; - Augmented a1 = this->expressionA1_->forward(values); - Augmented a2 = this->expressionA2_->forward(values); + Augmented a1 = this->template expression()->forward(values); + Augmented a2 = this->template expression()->forward(values); JacobianTA1 dTdA1; JacobianTA2 dTdA2; T t = function_(a1.value(), a2.value(), @@ -678,30 +674,29 @@ public: typedef Record Record; /// Construct an execution trace for reverse AD - /// The raw buffer is [Record | A1 raw | A2 raw] virtual T traceExecution(const Values& values, ExecutionTrace& trace, char* raw) const { Record* record = new (raw) Record(); trace.setFunction(record); raw = (char*) (record + 1); - A1 a1 = expressionA1_->traceExecution(values, + A1 a1 = this->template expression()->traceExecution(values, record->template trace(), raw); - raw = raw + expressionA1_->traceSize(); - A2 a2 = expressionA2_->traceExecution(values, + raw = raw + this->template expression()->traceSize(); + A2 a2 = this->template expression()->traceExecution(values, record->template trace(), raw); + raw = raw + this->template expression()->traceSize(); return function_(a1, a2, record->template jacobian(), record->template jacobian()); } - }; //----------------------------------------------------------------------------- /// Ternary Expression template -class TernaryExpression: public ExpressionNode { +class TernaryExpression: public FunctionalNode > { public: @@ -715,15 +710,14 @@ public: private: Function function_; - boost::shared_ptr > expressionA1_; - boost::shared_ptr > expressionA2_; - boost::shared_ptr > expressionA3_; /// Constructor with a ternary function f, and three input arguments TernaryExpression(Function f, const Expression& e1, const Expression& e2, const Expression& e3) : - function_(f), expressionA1_(e1.root()), expressionA2_(e2.root()), expressionA3_( - e3.root()) { + function_(f) { + this->template expression() = e1.root(); + this->template expression() = e2.root(); + this->template expression() = e3.root(); ExpressionNode::traceSize_ = // sizeof(Record) + e1.traceSize() + e2.traceSize() + e3.traceSize(); } @@ -734,9 +728,9 @@ public: /// Return keys that play in this expression virtual std::set keys() const { - std::set keys1 = expressionA1_->keys(); - std::set keys2 = expressionA2_->keys(); - std::set keys3 = expressionA3_->keys(); + std::set keys1 = this->template expression()->keys(); + std::set keys2 = this->template expression()->keys(); + std::set keys3 = this->template expression()->keys(); keys2.insert(keys3.begin(), keys3.end()); keys1.insert(keys2.begin(), keys2.end()); return keys1; @@ -745,17 +739,18 @@ public: /// Return value virtual T value(const Values& values) const { using boost::none; - return function_(this->expressionA1_->value(values), - this->expressionA2_->value(values), this->expressionA3_->value(values), + return function_(this->template expression()->value(values), + this->template expression()->value(values), + this->template expression()->value(values), none, none, none); } /// Return value and derivatives virtual Augmented forward(const Values& values) const { using boost::none; - Augmented a1 = this->expressionA1_->forward(values); - Augmented a2 = this->expressionA2_->forward(values); - Augmented a3 = this->expressionA3_->forward(values); + Augmented a1 = this->template expression()->forward(values); + Augmented a2 = this->template expression()->forward(values); + Augmented a3 = this->template expression()->forward(values); JacobianTA1 dTdA1; JacobianTA2 dTdA2; JacobianTA3 dTdA3; @@ -778,13 +773,13 @@ public: trace.setFunction(record); raw = (char*) (record + 1); - A1 a1 = expressionA1_->traceExecution(values, + A1 a1 = this->template expression()->traceExecution(values, record->template trace(), raw); - raw = raw + expressionA1_->traceSize(); - A2 a2 = expressionA2_->traceExecution(values, + raw = raw + this->template expression()->traceSize(); + A2 a2 = this->template expression()->traceExecution(values, record->template trace(), raw); - raw = raw + expressionA2_->traceSize(); - A3 a3 = expressionA3_->traceExecution(values, + raw = raw + this->template expression()->traceSize(); + A3 a3 = this->template expression()->traceExecution(values, record->template trace(), raw); return function_(a1, a2, a3, record->template jacobian(),