diff --git a/gtsam_unstable/nonlinear/Expression-inl.h b/gtsam_unstable/nonlinear/Expression-inl.h index c2f51ea96..06405579e 100644 --- a/gtsam_unstable/nonlinear/Expression-inl.h +++ b/gtsam_unstable/nonlinear/Expression-inl.h @@ -185,7 +185,7 @@ public: virtual Augmented forward(const Values& values) const = 0; /// Construct an execution trace for reverse AD - virtual std::pair traceExecution(const Values& values) const = 0; + virtual T traceExecution(const Values& values, TracePtr& p) const = 0; }; //----------------------------------------------------------------------------- @@ -236,9 +236,10 @@ public: }; /// Construct an execution trace for reverse AD - virtual std::pair traceExecution(const Values& values) const { + virtual T traceExecution(const Values& values, TracePtr& p) const { Trace* trace = new Trace(); - return std::make_pair(constant_, trace); + p = static_cast(trace); + return constant_; } }; @@ -299,10 +300,11 @@ public: }; /// Construct an execution trace for reverse AD - virtual std::pair traceExecution(const Values& values) const { + virtual T traceExecution(const Values& values, TracePtr& p) const { Trace* trace = new Trace(); + p = static_cast(trace); trace->key = key_; - return std::make_pair(values.at(key_), trace); + return values.at(key_); } }; @@ -373,11 +375,11 @@ public: }; /// Construct an execution trace for reverse AD - virtual std::pair traceExecution(const Values& values) const { - A a; + virtual T traceExecution(const Values& values, TracePtr& p) const { Trace* trace = new Trace(); - boost::tie(a, trace->trace) = this->expressionA_->traceExecution(values); - return std::make_pair(function_(a, trace->dTdA), trace); + p = static_cast(trace); + A a = this->expressionA_->traceExecution(values,trace->trace); + return function_(a, trace->dTdA); } }; @@ -465,13 +467,12 @@ public: }; /// Construct an execution trace for reverse AD - virtual std::pair traceExecution(const Values& values) const { - A1 a1; - A2 a2; + virtual T traceExecution(const Values& values, TracePtr& p) const { Trace* trace = new Trace(); - boost::tie(a1, trace->trace1) = this->expressionA1_->traceExecution(values); - boost::tie(a2, trace->trace2) = this->expressionA2_->traceExecution(values); - return std::make_pair(function_(a1, a2, trace->dTdA1, trace->dTdA2), trace); + p = static_cast(trace); + A1 a1 = this->expressionA1_->traceExecution(values,trace->trace1); + A2 a2 = this->expressionA2_->traceExecution(values,trace->trace2); + return function_(a1, a2, trace->dTdA1, trace->dTdA2); } }; @@ -576,16 +577,13 @@ public: }; /// Construct an execution trace for reverse AD - virtual std::pair traceExecution(const Values& values) const { - A1 a1; - A2 a2; - A3 a3; + virtual T traceExecution(const Values& values, TracePtr& p) const { Trace* trace = new Trace(); - boost::tie(a1, trace->trace1) = this->expressionA1_->traceExecution(values); - boost::tie(a2, trace->trace2) = this->expressionA2_->traceExecution(values); - boost::tie(a3, trace->trace3) = this->expressionA3_->traceExecution(values); - return std::make_pair( - function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3), trace); + p = static_cast(trace); + A1 a1 = this->expressionA1_->traceExecution(values,trace->trace1); + A2 a2 = this->expressionA2_->traceExecution(values,trace->trace2); + A3 a3 = this->expressionA3_->traceExecution(values,trace->trace3); + return function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3); } }; diff --git a/gtsam_unstable/nonlinear/Expression.h b/gtsam_unstable/nonlinear/Expression.h index cc7977a23..06265a9fb 100644 --- a/gtsam_unstable/nonlinear/Expression.h +++ b/gtsam_unstable/nonlinear/Expression.h @@ -117,9 +117,8 @@ public: Augmented augmented(const Values& values) const { #define REVERSE_AD #ifdef REVERSE_AD - T value; TracePtr trace; - boost::tie(value,trace) = root_->traceExecution(values); + T value = root_->traceExecution(values,trace); Augmented augmented(value); trace->reverseAD(augmented.jacobians()); delete trace;