Avoid argument temporaries
parent
84987aa351
commit
c4a92acde1
|
@ -185,7 +185,7 @@ public:
|
||||||
virtual Augmented<T> forward(const Values& values) const = 0;
|
virtual Augmented<T> forward(const Values& values) const = 0;
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const = 0;
|
virtual T traceExecution(const Values& values, TracePtr& p) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
|
@ -236,9 +236,10 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
virtual T traceExecution(const Values& values, TracePtr& p) const {
|
||||||
Trace* trace = new Trace();
|
Trace* trace = new Trace();
|
||||||
return std::make_pair(constant_, trace);
|
p = static_cast<TracePtr>(trace);
|
||||||
|
return constant_;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -299,10 +300,11 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
virtual T traceExecution(const Values& values, TracePtr& p) const {
|
||||||
Trace* trace = new Trace();
|
Trace* trace = new Trace();
|
||||||
|
p = static_cast<TracePtr>(trace);
|
||||||
trace->key = key_;
|
trace->key = key_;
|
||||||
return std::make_pair(values.at<T>(key_), trace);
|
return values.at<T>(key_);
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
@ -373,11 +375,11 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
virtual T traceExecution(const Values& values, TracePtr& p) const {
|
||||||
A a;
|
|
||||||
Trace* trace = new Trace();
|
Trace* trace = new Trace();
|
||||||
boost::tie(a, trace->trace) = this->expressionA_->traceExecution(values);
|
p = static_cast<TracePtr>(trace);
|
||||||
return std::make_pair(function_(a, trace->dTdA), trace);
|
A a = this->expressionA_->traceExecution(values,trace->trace);
|
||||||
|
return function_(a, trace->dTdA);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -465,13 +467,12 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
virtual T traceExecution(const Values& values, TracePtr& p) const {
|
||||||
A1 a1;
|
|
||||||
A2 a2;
|
|
||||||
Trace* trace = new Trace();
|
Trace* trace = new Trace();
|
||||||
boost::tie(a1, trace->trace1) = this->expressionA1_->traceExecution(values);
|
p = static_cast<TracePtr>(trace);
|
||||||
boost::tie(a2, trace->trace2) = this->expressionA2_->traceExecution(values);
|
A1 a1 = this->expressionA1_->traceExecution(values,trace->trace1);
|
||||||
return std::make_pair(function_(a1, a2, trace->dTdA1, trace->dTdA2), trace);
|
A2 a2 = this->expressionA2_->traceExecution(values,trace->trace2);
|
||||||
|
return function_(a1, a2, trace->dTdA1, trace->dTdA2);
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
@ -576,16 +577,13 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual std::pair<T, TracePtr> traceExecution(const Values& values) const {
|
virtual T traceExecution(const Values& values, TracePtr& p) const {
|
||||||
A1 a1;
|
|
||||||
A2 a2;
|
|
||||||
A3 a3;
|
|
||||||
Trace* trace = new Trace();
|
Trace* trace = new Trace();
|
||||||
boost::tie(a1, trace->trace1) = this->expressionA1_->traceExecution(values);
|
p = static_cast<TracePtr>(trace);
|
||||||
boost::tie(a2, trace->trace2) = this->expressionA2_->traceExecution(values);
|
A1 a1 = this->expressionA1_->traceExecution(values,trace->trace1);
|
||||||
boost::tie(a3, trace->trace3) = this->expressionA3_->traceExecution(values);
|
A2 a2 = this->expressionA2_->traceExecution(values,trace->trace2);
|
||||||
return std::make_pair(
|
A3 a3 = this->expressionA3_->traceExecution(values,trace->trace3);
|
||||||
function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3), trace);
|
return function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3);
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
|
@ -117,9 +117,8 @@ public:
|
||||||
Augmented<T> augmented(const Values& values) const {
|
Augmented<T> augmented(const Values& values) const {
|
||||||
#define REVERSE_AD
|
#define REVERSE_AD
|
||||||
#ifdef REVERSE_AD
|
#ifdef REVERSE_AD
|
||||||
T value;
|
|
||||||
TracePtr trace;
|
TracePtr trace;
|
||||||
boost::tie(value,trace) = root_->traceExecution(values);
|
T value = root_->traceExecution(values,trace);
|
||||||
Augmented<T> augmented(value);
|
Augmented<T> augmented(value);
|
||||||
trace->reverseAD(augmented.jacobians());
|
trace->reverseAD(augmented.jacobians());
|
||||||
delete trace;
|
delete trace;
|
||||||
|
|
Loading…
Reference in New Issue