Unary Trace done

release/4.3a0
dellaert 2014-10-05 13:33:23 +02:00
parent 8e527a2251
commit 75445307b2
1 changed files with 32 additions and 3 deletions

View File

@ -69,6 +69,11 @@ public:
jacobians_[key] = Eigen::MatrixXd::Identity(n, n); jacobians_[key] = Eigen::MatrixXd::Identity(n, n);
} }
/// Construct from value and JacobianMap
Augmented(const T& t, const JacobianMap& jacobians) :
value_(t), jacobians_(jacobians) {
}
/// Construct value, pre-multiply jacobians by H /// Construct value, pre-multiply jacobians by H
Augmented(const T& t, const Matrix& H, const JacobianMap& jacobians) : Augmented(const T& t, const Matrix& H, const JacobianMap& jacobians) :
value_(t) { value_(t) {
@ -76,7 +81,8 @@ public:
} }
/// Construct from value and two disjoint JacobianMaps /// Construct from value and two disjoint JacobianMaps
Augmented(const T& t, const JacobianMap& jacobians1, const JacobianMap& jacobians2) : Augmented(const T& t, const JacobianMap& jacobians1,
const JacobianMap& jacobians2) :
value_(t), jacobians_(jacobians1) { value_(t), jacobians_(jacobians1) {
jacobians_.insert(jacobians2.begin(), jacobians2.end()); jacobians_.insert(jacobians2.begin(), jacobians2.end());
} }
@ -288,6 +294,29 @@ public:
return Augmented<T>(t, H, argument.jacobians()); return Augmented<T>(t, H, argument.jacobians());
} }
/// Trace structure for reverse AD
typedef typename ExpressionNode<T>::Trace BaseTrace;
struct Trace: public BaseTrace {
boost::shared_ptr<typename ExpressionNode<A>::Trace> trace1;
Matrix H1;
T t;
/// Return value and derivatives
virtual Augmented<T> augmented(const Matrix& H) const {
// This is a top-down calculation
// The end-result needs Jacobians to all leaf nodes.
// Since this is not a leaf node, we compute what is needed for leaf nodes here
Augmented<A> augmented1 = trace1->augmented(H * H1);
return Augmented<T>(t, augmented1.jacobians());
}
};
/// Construct an execution trace for reverse AD
virtual boost::shared_ptr<BaseTrace> reverse(const Values& values) const {
boost::shared_ptr<Trace> trace = boost::make_shared<Trace>();
trace->trace1 = this->expressionA_->reverse(values);
trace->t = function_(trace->trace1->value(), trace->H1);
return trace;
}
}; };
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------