Tagged union, lightweight

release/4.3a0
dellaert 2014-10-08 15:39:59 +02:00
parent 390842e1f7
commit ce2dcaeb3b
2 changed files with 74 additions and 65 deletions

View File

@ -33,6 +33,54 @@ typedef std::map<Key, Matrix> JacobianMap;
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
struct JacobianTrace { struct JacobianTrace {
class Pointer {
enum {
Constant, Leaf, Function
} type;
union {
Key key;
JacobianTrace* ptr;
} content;
public:
/// Pointer always starts out as a Constant
Pointer() :
type(Constant) {
}
~Pointer() {
if (type == Function)
delete content.ptr;
}
void setLeaf(Key key) {
type = Leaf;
content.key = key;
}
void setFunction(JacobianTrace* trace) {
type = Function;
content.ptr = trace;
}
// Either add to Jacobians (Leaf) or propagate (Function)
template<class T>
void reverseAD(JacobianMap& jacobians) const {
if (type == Function)
content.ptr->reverseAD(jacobians);
else if (type == Leaf) {
size_t n = T::Dim();
jacobians[content.key] = Eigen::MatrixXd::Identity(n, n);
}
}
// Either add to Jacobians (Leaf) or propagate (Function)
void reverseAD(const Matrix& dTdA, JacobianMap& jacobians) const {
if (type == Function)
content.ptr->reverseAD(dTdA, jacobians);
else if (type == Leaf) {
JacobianMap::iterator it = jacobians.find(content.key);
if (it != jacobians.end())
it->second += dTdA;
else
jacobians[content.key] = dTdA;
}
}
};
virtual ~JacobianTrace() { virtual ~JacobianTrace() {
} }
virtual void reverseAD(JacobianMap& jacobians) const = 0; virtual void reverseAD(JacobianMap& jacobians) const = 0;
@ -41,7 +89,7 @@ struct JacobianTrace {
// void reverseAD(const JacobianFT& dFdT, JacobianMap& jacobians) const { // void reverseAD(const JacobianFT& dFdT, JacobianMap& jacobians) const {
}; };
typedef JacobianTrace* TracePtr; typedef JacobianTrace::Pointer TracePtr;
//template <class Derived> //template <class Derived>
//struct TypedTrace { //struct TypedTrace {
@ -225,20 +273,8 @@ public:
return Augmented<T>(constant_); return Augmented<T>(constant_);
} }
/// Trace structure for reverse AD
struct Trace: public JacobianTrace {
/// If the expression is just a constant, we do nothing
virtual void reverseAD(JacobianMap& jacobians) const {
}
/// Base case: we simply ignore the given df/dT
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
}
};
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, TracePtr& p) const { virtual T traceExecution(const Values& values, TracePtr& p) const {
Trace* trace = new Trace();
p = static_cast<TracePtr>(trace);
return constant_; return constant_;
} }
}; };
@ -281,29 +317,9 @@ public:
return Augmented<T>(values.at<T>(key_), key_); return Augmented<T>(values.at<T>(key_), key_);
} }
/// Trace structure for reverse AD
struct Trace: public JacobianTrace {
Key key;
/// If the expression is just a leaf, we just insert an identity matrix
virtual void reverseAD(JacobianMap& jacobians) const {
size_t n = T::Dim();
jacobians[key] = Eigen::MatrixXd::Identity(n, n);
}
/// Base case: given df/dT, add it jacobians with correct key and we are done
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
JacobianMap::iterator it = jacobians.find(key);
if (it != jacobians.end())
it->second += dFdT;
else
jacobians[key] = dFdT;
}
};
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, TracePtr& p) const { virtual T traceExecution(const Values& values, TracePtr& p) const {
Trace* trace = new Trace(); p.setLeaf(key_);
p = static_cast<TracePtr>(trace);
trace->key = key_;
return values.at<T>(key_); return values.at<T>(key_);
} }
@ -362,23 +378,22 @@ public:
TracePtr trace; TracePtr trace;
JacobianTA dTdA; JacobianTA dTdA;
virtual ~Trace() { virtual ~Trace() {
delete trace;
} }
/// Start the reverse AD process /// Start the reverse AD process
virtual void reverseAD(JacobianMap& jacobians) const { virtual void reverseAD(JacobianMap& jacobians) const {
trace->reverseAD(dTdA, jacobians); trace.reverseAD(dTdA, jacobians);
} }
/// Given df/dT, multiply in dT/dA and continue reverse AD process /// Given df/dT, multiply in dT/dA and continue reverse AD process
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
trace->reverseAD(dFdT * dTdA, jacobians); trace.reverseAD(dFdT * dTdA, jacobians);
} }
}; };
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, TracePtr& p) const { virtual T traceExecution(const Values& values, TracePtr& p) const {
Trace* trace = new Trace(); Trace* trace = new Trace();
p = static_cast<TracePtr>(trace); p.setFunction(trace);
A a = this->expressionA_->traceExecution(values,trace->trace); A a = this->expressionA_->traceExecution(values, trace->trace);
return function_(a, trace->dTdA); return function_(a, trace->dTdA);
} }
}; };
@ -451,27 +466,25 @@ public:
JacobianTA1 dTdA1; JacobianTA1 dTdA1;
JacobianTA2 dTdA2; JacobianTA2 dTdA2;
virtual ~Trace() { virtual ~Trace() {
delete trace1;
delete trace2;
} }
/// Start the reverse AD process /// Start the reverse AD process
virtual void reverseAD(JacobianMap& jacobians) const { virtual void reverseAD(JacobianMap& jacobians) const {
trace1->reverseAD(dTdA1, jacobians); trace1.reverseAD(dTdA1, jacobians);
trace2->reverseAD(dTdA2, jacobians); trace2.reverseAD(dTdA2, jacobians);
} }
/// Given df/dT, multiply in dT/dA and continue reverse AD process /// Given df/dT, multiply in dT/dA and continue reverse AD process
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
trace1->reverseAD(dFdT * dTdA1, jacobians); trace1.reverseAD(dFdT * dTdA1, jacobians);
trace2->reverseAD(dFdT * dTdA2, jacobians); trace2.reverseAD(dFdT * dTdA2, jacobians);
} }
}; };
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, TracePtr& p) const { virtual T traceExecution(const Values& values, TracePtr& p) const {
Trace* trace = new Trace(); Trace* trace = new Trace();
p = static_cast<TracePtr>(trace); p.setFunction(trace);
A1 a1 = this->expressionA1_->traceExecution(values,trace->trace1); A1 a1 = this->expressionA1_->traceExecution(values, trace->trace1);
A2 a2 = this->expressionA2_->traceExecution(values,trace->trace2); A2 a2 = this->expressionA2_->traceExecution(values, trace->trace2);
return function_(a1, a2, trace->dTdA1, trace->dTdA2); return function_(a1, a2, trace->dTdA1, trace->dTdA2);
} }
@ -558,31 +571,28 @@ public:
JacobianTA2 dTdA2; JacobianTA2 dTdA2;
JacobianTA3 dTdA3; JacobianTA3 dTdA3;
virtual ~Trace() { virtual ~Trace() {
delete trace1;
delete trace2;
delete trace3;
} }
/// Start the reverse AD process /// Start the reverse AD process
virtual void reverseAD(JacobianMap& jacobians) const { virtual void reverseAD(JacobianMap& jacobians) const {
trace1->reverseAD(dTdA1, jacobians); trace1.reverseAD(dTdA1, jacobians);
trace2->reverseAD(dTdA2, jacobians); trace2.reverseAD(dTdA2, jacobians);
trace3->reverseAD(dTdA3, jacobians); trace3.reverseAD(dTdA3, jacobians);
} }
/// Given df/dT, multiply in dT/dA and continue reverse AD process /// Given df/dT, multiply in dT/dA and continue reverse AD process
virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const { virtual void reverseAD(const Matrix& dFdT, JacobianMap& jacobians) const {
trace1->reverseAD(dFdT * dTdA1, jacobians); trace1.reverseAD(dFdT * dTdA1, jacobians);
trace2->reverseAD(dFdT * dTdA2, jacobians); trace2.reverseAD(dFdT * dTdA2, jacobians);
trace3->reverseAD(dFdT * dTdA3, jacobians); trace3.reverseAD(dFdT * dTdA3, jacobians);
} }
}; };
/// Construct an execution trace for reverse AD /// Construct an execution trace for reverse AD
virtual T traceExecution(const Values& values, TracePtr& p) const { virtual T traceExecution(const Values& values, TracePtr& p) const {
Trace* trace = new Trace(); Trace* trace = new Trace();
p = static_cast<TracePtr>(trace); p.setFunction(trace);
A1 a1 = this->expressionA1_->traceExecution(values,trace->trace1); A1 a1 = this->expressionA1_->traceExecution(values, trace->trace1);
A2 a2 = this->expressionA2_->traceExecution(values,trace->trace2); A2 a2 = this->expressionA2_->traceExecution(values, trace->trace2);
A3 a3 = this->expressionA3_->traceExecution(values,trace->trace3); A3 a3 = this->expressionA3_->traceExecution(values, trace->trace3);
return function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3); return function_(a1, a2, a3, trace->dTdA1, trace->dTdA2, trace->dTdA3);
} }

View File

@ -117,11 +117,10 @@ public:
Augmented<T> augmented(const Values& values) const { Augmented<T> augmented(const Values& values) const {
#define REVERSE_AD #define REVERSE_AD
#ifdef REVERSE_AD #ifdef REVERSE_AD
TracePtr trace; JacobianTrace::Pointer pointer;
T value = root_->traceExecution(values,trace); T value = root_->traceExecution(values,pointer);
Augmented<T> augmented(value); Augmented<T> augmented(value);
trace->reverseAD(augmented.jacobians()); pointer.reverseAD<T>(augmented.jacobians());
delete trace;
return augmented; return augmented;
#else #else
return root_->forward(values); return root_->forward(values);