New and consistent naming: ExecutionTrace = whole tree, CallRecord = local information left by the function.
parent
5cfe761f27
commit
23485a0e71
|
|
@ -32,99 +32,106 @@ class Expression;
|
||||||
typedef std::map<Key, Matrix> JacobianMap;
|
typedef std::map<Key, Matrix> JacobianMap;
|
||||||
|
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
/// The JacobinaTrace class records a tree-structured expression's execution
|
/**
|
||||||
|
* The CallRecord class stores the Jacobians of applying a function
|
||||||
|
* with respect to each of its arguments. It also stores an executation trace
|
||||||
|
* (defined below) for each of its arguments.
|
||||||
|
*
|
||||||
|
* It is sub-classed in the function-style ExpressionNode sub-classes below.
|
||||||
|
*/
|
||||||
template<class T>
|
template<class T>
|
||||||
struct JacobianTrace {
|
struct CallRecord {
|
||||||
|
|
||||||
// Some fixed matrix sizes we need
|
|
||||||
typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The Pointer class is a tagged union that obviates the need to create
|
|
||||||
* a JacobianTrace subclass for Constants and Leaf Expressions. Instead
|
|
||||||
* the key for the leaf is stored in the space normally used to store a
|
|
||||||
* JacobianTrace*. Nothing is stored for a Constant.
|
|
||||||
*/
|
|
||||||
///
|
|
||||||
class Pointer {
|
|
||||||
enum {
|
|
||||||
Constant, Leaf, Function
|
|
||||||
} type;
|
|
||||||
union {
|
|
||||||
Key key;
|
|
||||||
JacobianTrace* ptr;
|
|
||||||
} content;
|
|
||||||
public:
|
|
||||||
/// Pointer always starts out as a Constant
|
|
||||||
Pointer() :
|
|
||||||
type(Constant) {
|
|
||||||
}
|
|
||||||
/// Destructor cleans up pointer if Function
|
|
||||||
~Pointer() {
|
|
||||||
if (type == Function)
|
|
||||||
delete content.ptr;
|
|
||||||
}
|
|
||||||
/// Change pointer to a Leaf Trace
|
|
||||||
void setLeaf(Key key) {
|
|
||||||
type = Leaf;
|
|
||||||
content.key = key;
|
|
||||||
}
|
|
||||||
/// Take ownership of pointer to a Function Trace
|
|
||||||
void setFunction(JacobianTrace* trace) {
|
|
||||||
type = Function;
|
|
||||||
content.ptr = trace;
|
|
||||||
}
|
|
||||||
// *** This is the main entry point for reverseAD, called from Expression::augmented ***
|
|
||||||
// Called only once, either inserts identity into Jacobians (Leaf) or starts AD (Function)
|
|
||||||
void startReverseAD(JacobianMap& jacobians) const {
|
|
||||||
if (type == Leaf) {
|
|
||||||
// This branch will only be called on trivial Leaf expressions, i.e. Priors
|
|
||||||
size_t n = T::Dim();
|
|
||||||
jacobians[content.key] = Eigen::MatrixXd::Identity(n, n);
|
|
||||||
} else if (type == Function)
|
|
||||||
// This is the more typical entry point, starting the AD pipeline
|
|
||||||
// It is inside the startReverseAD that the correctly dimensioned pipeline is chosen.
|
|
||||||
content.ptr->startReverseAD(jacobians);
|
|
||||||
}
|
|
||||||
// Either add to Jacobians (Leaf) or propagate (Function)
|
|
||||||
void reverseAD(const Matrix& dTdA, JacobianMap& jacobians) const {
|
|
||||||
if (type == Leaf) {
|
|
||||||
JacobianMap::iterator it = jacobians.find(content.key);
|
|
||||||
if (it != jacobians.end())
|
|
||||||
it->second += dTdA;
|
|
||||||
else
|
|
||||||
jacobians[content.key] = dTdA;
|
|
||||||
} else if (type == Function)
|
|
||||||
content.ptr->reverseAD(dTdA, jacobians);
|
|
||||||
}
|
|
||||||
// Either add to Jacobians (Leaf) or propagate (Function)
|
|
||||||
void reverseAD2(const Jacobian2T& dTdA, JacobianMap& jacobians) const {
|
|
||||||
if (type == Leaf) {
|
|
||||||
JacobianMap::iterator it = jacobians.find(content.key);
|
|
||||||
if (it != jacobians.end())
|
|
||||||
it->second += dTdA;
|
|
||||||
else
|
|
||||||
jacobians[content.key] = dTdA;
|
|
||||||
} else if (type == Function)
|
|
||||||
content.ptr->reverseAD2(dTdA, jacobians);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Make sure destructor is virtual
|
/// Make sure destructor is virtual
|
||||||
virtual ~JacobianTrace() {
|
virtual ~CallRecord() {
|
||||||
}
|
}
|
||||||
virtual void startReverseAD(JacobianMap& jacobians) const = 0;
|
virtual void startReverseAD(JacobianMap& jacobians) const = 0;
|
||||||
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const = 0;
|
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const = 0;
|
||||||
|
typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T;
|
||||||
virtual void reverseAD2(const Jacobian2T& dFdT,
|
virtual void reverseAD2(const Jacobian2T& dFdT,
|
||||||
JacobianMap& jacobians) const = 0;
|
JacobianMap& jacobians) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//-----------------------------------------------------------------------------
|
||||||
|
/**
|
||||||
|
* The ExecutionTrace class records a tree-structured expression's execution
|
||||||
|
* It is a tagged union that obviates the need to create
|
||||||
|
* a ExecutionTrace subclass for Constants and Leaf Expressions. Instead
|
||||||
|
* the key for the leaf is stored in the space normally used to store a
|
||||||
|
* CallRecord*. Nothing is stored for a Constant.
|
||||||
|
*/
|
||||||
|
template<class T>
|
||||||
|
class ExecutionTrace {
|
||||||
|
enum {
|
||||||
|
Constant, Leaf, Function
|
||||||
|
} type;
|
||||||
|
union {
|
||||||
|
Key key;
|
||||||
|
CallRecord<T>* ptr;
|
||||||
|
} content;
|
||||||
|
public:
|
||||||
|
/// Pointer always starts out as a Constant
|
||||||
|
ExecutionTrace() :
|
||||||
|
type(Constant) {
|
||||||
|
}
|
||||||
|
/// Destructor cleans up pointer if Function
|
||||||
|
~ExecutionTrace() {
|
||||||
|
if (type == Function)
|
||||||
|
delete content.ptr;
|
||||||
|
}
|
||||||
|
/// Change pointer to a Leaf Record
|
||||||
|
void setLeaf(Key key) {
|
||||||
|
type = Leaf;
|
||||||
|
content.key = key;
|
||||||
|
}
|
||||||
|
/// Take ownership of pointer to a Function Record
|
||||||
|
void setFunction(CallRecord<T>* record) {
|
||||||
|
type = Function;
|
||||||
|
content.ptr = record;
|
||||||
|
}
|
||||||
|
// *** This is the main entry point for reverseAD, called from Expression::augmented ***
|
||||||
|
// Called only once, either inserts identity into Jacobians (Leaf) or starts AD (Function)
|
||||||
|
void startReverseAD(JacobianMap& jacobians) const {
|
||||||
|
if (type == Leaf) {
|
||||||
|
// This branch will only be called on trivial Leaf expressions, i.e. Priors
|
||||||
|
size_t n = T::Dim();
|
||||||
|
jacobians[content.key] = Eigen::MatrixXd::Identity(n, n);
|
||||||
|
} else if (type == Function)
|
||||||
|
// This is the more typical entry point, starting the AD pipeline
|
||||||
|
// It is inside the startReverseAD that the correctly dimensioned pipeline is chosen.
|
||||||
|
content.ptr->startReverseAD(jacobians);
|
||||||
|
}
|
||||||
|
// Either add to Jacobians (Leaf) or propagate (Function)
|
||||||
|
void reverseAD(const Matrix& dTdA, JacobianMap& jacobians) const {
|
||||||
|
if (type == Leaf) {
|
||||||
|
JacobianMap::iterator it = jacobians.find(content.key);
|
||||||
|
if (it != jacobians.end())
|
||||||
|
it->second += dTdA;
|
||||||
|
else
|
||||||
|
jacobians[content.key] = dTdA;
|
||||||
|
} else if (type == Function)
|
||||||
|
content.ptr->reverseAD(dTdA, jacobians);
|
||||||
|
}
|
||||||
|
// Either add to Jacobians (Leaf) or propagate (Function)
|
||||||
|
typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T;
|
||||||
|
void reverseAD2(const Jacobian2T& dTdA, JacobianMap& jacobians) const {
|
||||||
|
if (type == Leaf) {
|
||||||
|
JacobianMap::iterator it = jacobians.find(content.key);
|
||||||
|
if (it != jacobians.end())
|
||||||
|
it->second += dTdA;
|
||||||
|
else
|
||||||
|
jacobians[content.key] = dTdA;
|
||||||
|
} else if (type == Function)
|
||||||
|
content.ptr->reverseAD2(dTdA, jacobians);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Primary template calls the generic Matrix reverseAD pipeline
|
/// Primary template calls the generic Matrix reverseAD pipeline
|
||||||
template<size_t M, class A>
|
template<size_t M, class A>
|
||||||
struct Select {
|
struct Select {
|
||||||
typedef Eigen::Matrix<double, M, A::dimension> Jacobian;
|
typedef Eigen::Matrix<double, M, A::dimension> Jacobian;
|
||||||
static void reverseAD(const typename JacobianTrace<A>::Pointer& trace,
|
static void reverseAD(const ExecutionTrace<A>& trace, const Jacobian& dTdA,
|
||||||
const Jacobian& dTdA, JacobianMap& jacobians) {
|
JacobianMap& jacobians) {
|
||||||
trace.reverseAD(dTdA, jacobians);
|
trace.reverseAD(dTdA, jacobians);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -133,8 +140,8 @@ struct Select {
|
||||||
template<class A>
|
template<class A>
|
||||||
struct Select<2, A> {
|
struct Select<2, A> {
|
||||||
typedef Eigen::Matrix<double, 2, A::dimension> Jacobian;
|
typedef Eigen::Matrix<double, 2, A::dimension> Jacobian;
|
||||||
static void reverseAD(const typename JacobianTrace<A>::Pointer& trace,
|
static void reverseAD(const ExecutionTrace<A>& trace, const Jacobian& dTdA,
|
||||||
const Jacobian& dTdA, JacobianMap& jacobians) {
|
JacobianMap& jacobians) {
|
||||||
trace.reverseAD2(dTdA, jacobians);
|
trace.reverseAD2(dTdA, jacobians);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -292,7 +299,7 @@ public:
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values,
|
virtual T traceExecution(const Values& values,
|
||||||
typename JacobianTrace<T>::Pointer& p) const = 0;
|
ExecutionTrace<T>& p) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
|
|
@ -329,8 +336,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& p) const {
|
||||||
typename JacobianTrace<T>::Pointer& p) const {
|
|
||||||
return constant_;
|
return constant_;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -370,8 +376,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& p) const {
|
||||||
typename JacobianTrace<T>::Pointer& p) const {
|
|
||||||
p.setLeaf(key_);
|
p.setLeaf(key_);
|
||||||
return values.at<T>(key_);
|
return values.at<T>(key_);
|
||||||
}
|
}
|
||||||
|
|
@ -422,9 +427,9 @@ public:
|
||||||
return Augmented<T>(t, dTdA, argument.jacobians());
|
return Augmented<T>(t, dTdA, argument.jacobians());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trace structure for reverse AD
|
/// Record structure for reverse AD
|
||||||
struct Trace: public JacobianTrace<T> {
|
struct Record: public CallRecord<T> {
|
||||||
typename JacobianTrace<A1>::Pointer trace1;
|
ExecutionTrace<A1> trace1;
|
||||||
JacobianTA dTdA1;
|
JacobianTA dTdA1;
|
||||||
|
|
||||||
/// Start the reverse AD process
|
/// Start the reverse AD process
|
||||||
|
|
@ -444,12 +449,11 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& p) const {
|
||||||
typename JacobianTrace<T>::Pointer& p) const {
|
Record* record = new Record();
|
||||||
Trace* trace = new Trace();
|
p.setFunction(record);
|
||||||
p.setFunction(trace);
|
A1 a = this->expressionA1_->traceExecution(values, record->trace1);
|
||||||
A1 a = this->expressionA1_->traceExecution(values, trace->trace1);
|
return function_(a, record->dTdA1);
|
||||||
return function_(a, trace->dTdA1);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -511,10 +515,10 @@ public:
|
||||||
return Augmented<T>(t, dTdA1, a1.jacobians(), dTdA2, a2.jacobians());
|
return Augmented<T>(t, dTdA1, a1.jacobians(), dTdA2, a2.jacobians());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trace structure for reverse AD
|
/// Record structure for reverse AD
|
||||||
struct Trace: public JacobianTrace<T> {
|
struct Record: public CallRecord<T> {
|
||||||
typename JacobianTrace<A1>::Pointer trace1;
|
ExecutionTrace<A1> trace1;
|
||||||
typename JacobianTrace<A2>::Pointer trace2;
|
ExecutionTrace<A2> trace2;
|
||||||
JacobianTA1 dTdA1;
|
JacobianTA1 dTdA1;
|
||||||
JacobianTA2 dTdA2;
|
JacobianTA2 dTdA2;
|
||||||
|
|
||||||
|
|
@ -538,13 +542,12 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& p) const {
|
||||||
typename JacobianTrace<T>::Pointer& p) const {
|
Record* record = new Record();
|
||||||
Trace* trace = new Trace();
|
p.setFunction(record);
|
||||||
p.setFunction(trace);
|
A1 a1 = this->expressionA1_->traceExecution(values, record->trace1);
|
||||||
A1 a1 = this->expressionA1_->traceExecution(values, trace->trace1);
|
A2 a2 = this->expressionA2_->traceExecution(values, record->trace2);
|
||||||
A2 a2 = this->expressionA2_->traceExecution(values, trace->trace2);
|
return function_(a1, a2, record->dTdA1, record->dTdA2);
|
||||||
return function_(a1, a2, trace->dTdA1, trace->dTdA2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
@ -619,11 +622,11 @@ public:
|
||||||
a3.jacobians());
|
a3.jacobians());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trace structure for reverse AD
|
/// Record structure for reverse AD
|
||||||
struct Trace: public JacobianTrace<T> {
|
struct Record: public CallRecord<T> {
|
||||||
typename JacobianTrace<A1>::Pointer trace1;
|
ExecutionTrace<A1> trace1;
|
||||||
typename JacobianTrace<A2>::Pointer trace2;
|
ExecutionTrace<A2> trace2;
|
||||||
typename JacobianTrace<A3>::Pointer trace3;
|
ExecutionTrace<A3> trace3;
|
||||||
JacobianTA1 dTdA1;
|
JacobianTA1 dTdA1;
|
||||||
JacobianTA2 dTdA2;
|
JacobianTA2 dTdA2;
|
||||||
JacobianTA3 dTdA3;
|
JacobianTA3 dTdA3;
|
||||||
|
|
@ -651,14 +654,13 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& p) const {
|
||||||
typename JacobianTrace<T>::Pointer& p) const {
|
Record* record = new Record();
|
||||||
Trace* trace = new Trace();
|
p.setFunction(record);
|
||||||
p.setFunction(trace);
|
A1 a1 = this->expressionA1_->traceExecution(values, record->trace1);
|
||||||
A1 a1 = this->expressionA1_->traceExecution(values, trace->trace1);
|
A2 a2 = this->expressionA2_->traceExecution(values, record->trace2);
|
||||||
A2 a2 = this->expressionA2_->traceExecution(values, trace->trace2);
|
A3 a3 = this->expressionA3_->traceExecution(values, record->trace3);
|
||||||
A3 a3 = this->expressionA3_->traceExecution(values, trace->trace3);
|
return function_(a1, a2, a3, record->dTdA1, record->dTdA2, record->dTdA3);
|
||||||
return function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -117,10 +117,10 @@ public:
|
||||||
Augmented<T> augmented(const Values& values) const {
|
Augmented<T> augmented(const Values& values) const {
|
||||||
#define REVERSE_AD
|
#define REVERSE_AD
|
||||||
#ifdef REVERSE_AD
|
#ifdef REVERSE_AD
|
||||||
typename JacobianTrace<T>::Pointer pointer;
|
ExecutionTrace<T> trace;
|
||||||
T value = root_->traceExecution(values,pointer);
|
T value = root_->traceExecution(values, trace);
|
||||||
Augmented<T> augmented(value);
|
Augmented<T> augmented(value);
|
||||||
pointer.startReverseAD(augmented.jacobians());
|
trace.startReverseAD(augmented.jacobians());
|
||||||
return augmented;
|
return augmented;
|
||||||
#else
|
#else
|
||||||
return root_->forward(values);
|
return root_->forward(values);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue