Unary Trace done
parent
8e527a2251
commit
75445307b2
|
@ -69,6 +69,11 @@ public:
|
||||||
jacobians_[key] = Eigen::MatrixXd::Identity(n, n);
|
jacobians_[key] = Eigen::MatrixXd::Identity(n, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Construct from value and JacobianMap
|
||||||
|
Augmented(const T& t, const JacobianMap& jacobians) :
|
||||||
|
value_(t), jacobians_(jacobians) {
|
||||||
|
}
|
||||||
|
|
||||||
/// Construct value, pre-multiply jacobians by H
|
/// Construct value, pre-multiply jacobians by H
|
||||||
Augmented(const T& t, const Matrix& H, const JacobianMap& jacobians) :
|
Augmented(const T& t, const Matrix& H, const JacobianMap& jacobians) :
|
||||||
value_(t) {
|
value_(t) {
|
||||||
|
@ -76,7 +81,8 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct from value and two disjoint JacobianMaps
|
/// Construct from value and two disjoint JacobianMaps
|
||||||
Augmented(const T& t, const JacobianMap& jacobians1, const JacobianMap& jacobians2) :
|
Augmented(const T& t, const JacobianMap& jacobians1,
|
||||||
|
const JacobianMap& jacobians2) :
|
||||||
value_(t), jacobians_(jacobians1) {
|
value_(t), jacobians_(jacobians1) {
|
||||||
jacobians_.insert(jacobians2.begin(), jacobians2.end());
|
jacobians_.insert(jacobians2.begin(), jacobians2.end());
|
||||||
}
|
}
|
||||||
|
@ -288,6 +294,29 @@ public:
|
||||||
return Augmented<T>(t, H, argument.jacobians());
|
return Augmented<T>(t, H, argument.jacobians());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Trace structure for reverse AD
|
||||||
|
typedef typename ExpressionNode<T>::Trace BaseTrace;
|
||||||
|
struct Trace: public BaseTrace {
|
||||||
|
boost::shared_ptr<typename ExpressionNode<A>::Trace> trace1;
|
||||||
|
Matrix H1;
|
||||||
|
T t;
|
||||||
|
/// Return value and derivatives
|
||||||
|
virtual Augmented<T> augmented(const Matrix& H) const {
|
||||||
|
// This is a top-down calculation
|
||||||
|
// The end-result needs Jacobians to all leaf nodes.
|
||||||
|
// Since this is not a leaf node, we compute what is needed for leaf nodes here
|
||||||
|
Augmented<A> augmented1 = trace1->augmented(H * H1);
|
||||||
|
return Augmented<T>(t, augmented1.jacobians());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Construct an execution trace for reverse AD
|
||||||
|
virtual boost::shared_ptr<BaseTrace> reverse(const Values& values) const {
|
||||||
|
boost::shared_ptr<Trace> trace = boost::make_shared<Trace>();
|
||||||
|
trace->trace1 = this->expressionA_->reverse(values);
|
||||||
|
trace->t = function_(trace->trace1->value(), trace->H1);
|
||||||
|
return trace;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
|
|
Loading…
Reference in New Issue