Documentation
parent
4ac065fab4
commit
6a1bc6e242
|
@ -32,7 +32,15 @@ class Expression;
|
||||||
typedef std::map<Key, Matrix> JacobianMap;
|
typedef std::map<Key, Matrix> JacobianMap;
|
||||||
|
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
|
/// The JacobinaTrace class records a tree-structured expression's execution
|
||||||
struct JacobianTrace {
|
struct JacobianTrace {
|
||||||
|
/**
|
||||||
|
* The Pointer class is a tagged union that obviates the need to create
|
||||||
|
* a JacobianTrace subclass for Constants and Leaf Expressions. Instead
|
||||||
|
* the key for the leaf is stored in the space normally used to store a
|
||||||
|
* JacobianTrace*. Nothing is stored for a Constant.
|
||||||
|
*/
|
||||||
|
///
|
||||||
class Pointer {
|
class Pointer {
|
||||||
enum {
|
enum {
|
||||||
Constant, Leaf, Function
|
Constant, Leaf, Function
|
||||||
|
@ -46,41 +54,43 @@ struct JacobianTrace {
|
||||||
Pointer() :
|
Pointer() :
|
||||||
type(Constant) {
|
type(Constant) {
|
||||||
}
|
}
|
||||||
|
/// Destructor cleans up pointer if Function
|
||||||
~Pointer() {
|
~Pointer() {
|
||||||
if (type == Function)
|
if (type == Function)
|
||||||
delete content.ptr;
|
delete content.ptr;
|
||||||
}
|
}
|
||||||
|
/// Change pointer to a Leaf Trace
|
||||||
void setLeaf(Key key) {
|
void setLeaf(Key key) {
|
||||||
type = Leaf;
|
type = Leaf;
|
||||||
content.key = key;
|
content.key = key;
|
||||||
}
|
}
|
||||||
|
/// Take ownership of pointer to a Function Trace
|
||||||
void setFunction(JacobianTrace* trace) {
|
void setFunction(JacobianTrace* trace) {
|
||||||
type = Function;
|
type = Function;
|
||||||
content.ptr = trace;
|
content.ptr = trace;
|
||||||
}
|
}
|
||||||
// Either add to Jacobians (Leaf) or propagate (Function)
|
// Either insert identity into Jacobians (Leaf) or propagate (Function)
|
||||||
template<class T>
|
template<class T>
|
||||||
void reverseAD(JacobianMap& jacobians) const {
|
void reverseAD(JacobianMap& jacobians) const {
|
||||||
if (type == Function)
|
if (type == Leaf) {
|
||||||
content.ptr->reverseAD(jacobians);
|
|
||||||
else if (type == Leaf) {
|
|
||||||
size_t n = T::Dim();
|
size_t n = T::Dim();
|
||||||
jacobians[content.key] = Eigen::MatrixXd::Identity(n, n);
|
jacobians[content.key] = Eigen::MatrixXd::Identity(n, n);
|
||||||
}
|
} else if (type == Function)
|
||||||
|
content.ptr->reverseAD(jacobians);
|
||||||
}
|
}
|
||||||
// Either add to Jacobians (Leaf) or propagate (Function)
|
// Either add to Jacobians (Leaf) or propagate (Function)
|
||||||
void reverseAD(const Matrix& dTdA, JacobianMap& jacobians) const {
|
void reverseAD(const Matrix& dTdA, JacobianMap& jacobians) const {
|
||||||
if (type == Function)
|
if (type == Leaf) {
|
||||||
content.ptr->reverseAD(dTdA, jacobians);
|
|
||||||
else if (type == Leaf) {
|
|
||||||
JacobianMap::iterator it = jacobians.find(content.key);
|
JacobianMap::iterator it = jacobians.find(content.key);
|
||||||
if (it != jacobians.end())
|
if (it != jacobians.end())
|
||||||
it->second += dTdA;
|
it->second += dTdA;
|
||||||
else
|
else
|
||||||
jacobians[content.key] = dTdA;
|
jacobians[content.key] = dTdA;
|
||||||
}
|
} else if (type == Function)
|
||||||
|
content.ptr->reverseAD(dTdA, jacobians);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
/// Make sure destructor is virtual
|
||||||
virtual ~JacobianTrace() {
|
virtual ~JacobianTrace() {
|
||||||
}
|
}
|
||||||
virtual void reverseAD(JacobianMap& jacobians) const = 0;
|
virtual void reverseAD(JacobianMap& jacobians) const = 0;
|
||||||
|
|
Loading…
Reference in New Issue