Avoid argument temporaries
parent
84987aa351
commit
c4a92acde1
|
@ -185,7 +185,7 @@ public:
|
|||
virtual Augmented<T> forward(const Values& values) const = 0;
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const = 0;
|
||||
virtual T traceExecution(const Values& values, TracePtr& p) const = 0;
|
||||
};
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
|
@ -236,9 +236,10 @@ public:
|
|||
};
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
||||
virtual T traceExecution(const Values& values, TracePtr& p) const {
|
||||
Trace* trace = new Trace();
|
||||
return std::make_pair(constant_, trace);
|
||||
p = static_cast<TracePtr>(trace);
|
||||
return constant_;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -299,10 +300,11 @@ public:
|
|||
};
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
||||
virtual T traceExecution(const Values& values, TracePtr& p) const {
|
||||
Trace* trace = new Trace();
|
||||
p = static_cast<TracePtr>(trace);
|
||||
trace->key = key_;
|
||||
return std::make_pair(values.at<T>(key_), trace);
|
||||
return values.at<T>(key_);
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -373,11 +375,11 @@ public:
|
|||
};
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
||||
A a;
|
||||
virtual T traceExecution(const Values& values, TracePtr& p) const {
|
||||
Trace* trace = new Trace();
|
||||
boost::tie(a, trace->trace) = this->expressionA_->traceExecution(values);
|
||||
return std::make_pair(function_(a, trace->dTdA), trace);
|
||||
p = static_cast<TracePtr>(trace);
|
||||
A a = this->expressionA_->traceExecution(values,trace->trace);
|
||||
return function_(a, trace->dTdA);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -465,13 +467,12 @@ public:
|
|||
};
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
||||
A1 a1;
|
||||
A2 a2;
|
||||
virtual T traceExecution(const Values& values, TracePtr& p) const {
|
||||
Trace* trace = new Trace();
|
||||
boost::tie(a1, trace->trace1) = this->expressionA1_->traceExecution(values);
|
||||
boost::tie(a2, trace->trace2) = this->expressionA2_->traceExecution(values);
|
||||
return std::make_pair(function_(a1, a2, trace->dTdA1, trace->dTdA2), trace);
|
||||
p = static_cast<TracePtr>(trace);
|
||||
A1 a1 = this->expressionA1_->traceExecution(values,trace->trace1);
|
||||
A2 a2 = this->expressionA2_->traceExecution(values,trace->trace2);
|
||||
return function_(a1, a2, trace->dTdA1, trace->dTdA2);
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -576,16 +577,13 @@ public:
|
|||
};
|
||||
|
||||
/// Construct an execution trace for reverse AD
|
||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
||||
A1 a1;
|
||||
A2 a2;
|
||||
A3 a3;
|
||||
virtual T traceExecution(const Values& values, TracePtr& p) const {
|
||||
Trace* trace = new Trace();
|
||||
boost::tie(a1, trace->trace1) = this->expressionA1_->traceExecution(values);
|
||||
boost::tie(a2, trace->trace2) = this->expressionA2_->traceExecution(values);
|
||||
boost::tie(a3, trace->trace3) = this->expressionA3_->traceExecution(values);
|
||||
return std::make_pair(
|
||||
function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3), trace);
|
||||
p = static_cast<TracePtr>(trace);
|
||||
A1 a1 = this->expressionA1_->traceExecution(values,trace->trace1);
|
||||
A2 a2 = this->expressionA2_->traceExecution(values,trace->trace2);
|
||||
A3 a3 = this->expressionA3_->traceExecution(values,trace->trace3);
|
||||
return function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3);
|
||||
}
|
||||
|
||||
};
|
||||
|
|
|
@ -117,9 +117,8 @@ public:
|
|||
Augmented<T> augmented(const Values& values) const {
|
||||
#define REVERSE_AD
|
||||
#ifdef REVERSE_AD
|
||||
T value;
|
||||
TracePtr trace;
|
||||
boost::tie(value,trace) = root_->traceExecution(values);
|
||||
T value = root_->traceExecution(values,trace);
|
||||
Augmented<T> augmented(value);
|
||||
trace->reverseAD(augmented.jacobians());
|
||||
delete trace;
|
||||
|
|
Loading…
Reference in New Issue