Unary Trace done

release/4.3a0
dellaert 2014-10-05 13:33:23 +02:00
parent 8e527a2251
commit 75445307b2
1 changed files with 32 additions and 3 deletions

View File

@ -69,6 +69,11 @@ public:
jacobians_[key] = Eigen::MatrixXd::Identity(n, n);
}
/// Construct from value and JacobianMap
Augmented(const T& t, const JacobianMap& jacobians) :
value_(t), jacobians_(jacobians) {
}
/// Construct value, pre-multiply jacobians by H
Augmented(const T& t, const Matrix& H, const JacobianMap& jacobians) :
value_(t) {
@ -76,7 +81,8 @@ public:
}
/// Construct from value and two disjoint JacobianMaps
Augmented(const T& t, const JacobianMap& jacobians1, const JacobianMap& jacobians2) :
Augmented(const T& t, const JacobianMap& jacobians1,
const JacobianMap& jacobians2) :
value_(t), jacobians_(jacobians1) {
jacobians_.insert(jacobians2.begin(), jacobians2.end());
}
@ -288,6 +294,29 @@ public:
return Augmented<T>(t, H, argument.jacobians());
}
/// Trace structure for reverse AD
typedef typename ExpressionNode<T>::Trace BaseTrace;
struct Trace: public BaseTrace {
boost::shared_ptr<typename ExpressionNode<A>::Trace> trace1;
Matrix H1;
T t;
/// Return value and derivatives
virtual Augmented<T> augmented(const Matrix& H) const {
// This is a top-down calculation
// The end-result needs Jacobians to all leaf nodes.
// Since this is not a leaf node, we compute what is needed for leaf nodes here
Augmented<A> augmented1 = trace1->augmented(H * H1);
return Augmented<T>(t, augmented1.jacobians());
}
};
/// Construct an execution trace for reverse AD
virtual boost::shared_ptr<BaseTrace> reverse(const Values& values) const {
boost::shared_ptr<Trace> trace = boost::make_shared<Trace>();
trace->trace1 = this->expressionA_->reverse(values);
trace->t = function_(trace->trace1->value(), trace->H1);
return trace;
}
};
//-----------------------------------------------------------------------------
@ -362,8 +391,8 @@ public:
// The end-result needs Jacobians to all leaf nodes.
// Since this is not a leaf node, we compute what is needed for leaf nodes here
// The binary node represents a fork in the tree, and hence we will get two Augmented maps
Augmented<A1> augmented1 = trace1->augmented(H*H1);
Augmented<A1> augmented2 = trace1->augmented(H*H2);
Augmented<A1> augmented1 = trace1->augmented(H * H1);
Augmented<A1> augmented2 = trace1->augmented(H * H2);
return Augmented<T>(t, augmented1.jacobians(), augmented2.jacobians());
}
};