Avoid argument temporaries

release/4.3a0
dellaert 2014-10-07 19:35:44 +02:00
parent 84987aa351
commit c4a92acde1
2 changed files with 23 additions and 26 deletions

View File

@ -185,7 +185,7 @@ public:
virtual Augmented<T> forward(const Values& values) const = 0;
/// Construct an execution trace for reverse AD
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const = 0;
virtual T traceExecution(const Values& values, TracePtr& p) const = 0;
};
//-----------------------------------------------------------------------------
@ -236,9 +236,10 @@ public:
};
/// Construct an execution trace for reverse AD
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
virtual T traceExecution(const Values& values, TracePtr& p) const {
Trace* trace = new Trace();
return std::make_pair(constant_, trace);
p = static_cast<TracePtr>(trace);
return constant_;
}
};
@ -299,10 +300,11 @@ public:
};
/// Construct an execution trace for reverse AD
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
virtual T traceExecution(const Values& values, TracePtr& p) const {
Trace* trace = new Trace();
p = static_cast<TracePtr>(trace);
trace->key = key_;
return std::make_pair(values.at<T>(key_), trace);
return values.at<T>(key_);
}
};
@ -373,11 +375,11 @@ public:
};
/// Construct an execution trace for reverse AD
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
A a;
virtual T traceExecution(const Values& values, TracePtr& p) const {
Trace* trace = new Trace();
boost::tie(a, trace->trace) = this->expressionA_->traceExecution(values);
return std::make_pair(function_(a, trace->dTdA), trace);
p = static_cast<TracePtr>(trace);
A a = this->expressionA_->traceExecution(values,trace->trace);
return function_(a, trace->dTdA);
}
};
@ -465,13 +467,12 @@ public:
};
/// Construct an execution trace for reverse AD
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
A1 a1;
A2 a2;
virtual T traceExecution(const Values& values, TracePtr& p) const {
Trace* trace = new Trace();
boost::tie(a1, trace->trace1) = this->expressionA1_->traceExecution(values);
boost::tie(a2, trace->trace2) = this->expressionA2_->traceExecution(values);
return std::make_pair(function_(a1, a2, trace->dTdA1, trace->dTdA2), trace);
p = static_cast<TracePtr>(trace);
A1 a1 = this->expressionA1_->traceExecution(values,trace->trace1);
A2 a2 = this->expressionA2_->traceExecution(values,trace->trace2);
return function_(a1, a2, trace->dTdA1, trace->dTdA2);
}
};
@ -576,16 +577,13 @@ public:
};
/// Construct an execution trace for reverse AD
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
A1 a1;
A2 a2;
A3 a3;
virtual T traceExecution(const Values& values, TracePtr& p) const {
Trace* trace = new Trace();
boost::tie(a1, trace->trace1) = this->expressionA1_->traceExecution(values);
boost::tie(a2, trace->trace2) = this->expressionA2_->traceExecution(values);
boost::tie(a3, trace->trace3) = this->expressionA3_->traceExecution(values);
return std::make_pair(
function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3), trace);
p = static_cast<TracePtr>(trace);
A1 a1 = this->expressionA1_->traceExecution(values,trace->trace1);
A2 a2 = this->expressionA2_->traceExecution(values,trace->trace2);
A3 a3 = this->expressionA3_->traceExecution(values,trace->trace3);
return function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3);
}
};

View File

@ -117,9 +117,8 @@ public:
Augmented<T> augmented(const Values& values) const {
#define REVERSE_AD
#ifdef REVERSE_AD
T value;
TracePtr trace;
boost::tie(value,trace) = root_->traceExecution(values);
T value = root_->traceExecution(values,trace);
Augmented<T> augmented(value);
trace->reverseAD(augmented.jacobians());
delete trace;