Re-factor, allow traceExecution
parent
1f692638f5
commit
05f78b6dca
|
@ -113,18 +113,33 @@ public:
|
||||||
return root_->value(values);
|
return root_->value(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return value and derivatives
|
/// Return value and derivatives, forward AD version
|
||||||
Augmented<T> augmented(const Values& values) const {
|
Augmented<T> forward(const Values& values) const {
|
||||||
#define REVERSE_AD
|
return root_->forward(values);
|
||||||
#ifdef REVERSE_AD
|
}
|
||||||
|
|
||||||
|
/// trace execution, very unsafe, for testing purposes only
|
||||||
|
T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
|
void* raw) const {
|
||||||
|
return root_->traceExecution(values, trace, raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return value and derivatives, reverse AD version
|
||||||
|
Augmented<T> reverse(const Values& values) const {
|
||||||
char raw[10];
|
char raw[10];
|
||||||
ExecutionTrace<T> trace;
|
ExecutionTrace<T> trace;
|
||||||
T value(root_->traceExecution(values, trace, raw));
|
T value(root_->traceExecution(values, trace, raw));
|
||||||
Augmented<T> augmented(value);
|
Augmented<T> augmented(value);
|
||||||
trace.startReverseAD(augmented.jacobians());
|
trace.startReverseAD(augmented.jacobians());
|
||||||
return augmented;
|
return augmented;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return value and derivatives
|
||||||
|
Augmented<T> augmented(const Values& values) const {
|
||||||
|
#ifdef EXPRESSION_FORWARD_AD
|
||||||
|
return forward(values);
|
||||||
#else
|
#else
|
||||||
return root_->forward(values);
|
return reverse(values);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue