inlined a fully specialized function template defined in a .hpp
parent
a94835a2e4
commit
bd3f9db7df
|
@ -58,9 +58,9 @@ class Expression;
|
||||||
class JacobianMap {
|
class JacobianMap {
|
||||||
const FastVector<Key>& keys_;
|
const FastVector<Key>& keys_;
|
||||||
VerticalBlockMatrix& Ab_;
|
VerticalBlockMatrix& Ab_;
|
||||||
public:
|
public:
|
||||||
JacobianMap(const FastVector<Key>& keys, VerticalBlockMatrix& Ab) :
|
JacobianMap(const FastVector<Key>& keys, VerticalBlockMatrix& Ab) :
|
||||||
keys_(keys), Ab_(Ab) {
|
keys_(keys), Ab_(Ab) {
|
||||||
}
|
}
|
||||||
/// Access via key
|
/// Access via key
|
||||||
VerticalBlockMatrix::Block operator()(Key key) {
|
VerticalBlockMatrix::Block operator()(Key key) {
|
||||||
|
@ -89,7 +89,7 @@ struct CallRecord {
|
||||||
}
|
}
|
||||||
typedef Eigen::Matrix<double, 2, COLS> Jacobian2T;
|
typedef Eigen::Matrix<double, 2, COLS> Jacobian2T;
|
||||||
virtual void reverseAD2(const Jacobian2T& dFdT,
|
virtual void reverseAD2(const Jacobian2T& dFdT,
|
||||||
JacobianMap& jacobians) const {
|
JacobianMap& jacobians) const {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -97,12 +97,17 @@ struct CallRecord {
|
||||||
/// Handle Leaf Case: reverseAD ends here, by writing a matrix into Jacobians
|
/// Handle Leaf Case: reverseAD ends here, by writing a matrix into Jacobians
|
||||||
template<int ROWS, int COLS>
|
template<int ROWS, int COLS>
|
||||||
void handleLeafCase(const Eigen::Matrix<double, ROWS, COLS>& dTdA,
|
void handleLeafCase(const Eigen::Matrix<double, ROWS, COLS>& dTdA,
|
||||||
JacobianMap& jacobians, Key key) {
|
JacobianMap& jacobians, Key key) {
|
||||||
jacobians(key).block < ROWS, COLS > (0, 0) += dTdA; // block makes HUGE difference
|
// if (ROWS == -1 && COLS == -1 ) {
|
||||||
|
// jacobians(key) += dTdA;
|
||||||
|
// } else {
|
||||||
|
jacobians(key).block < ROWS, COLS > (0, 0) += dTdA; // block makes HUGE difference
|
||||||
|
// }
|
||||||
|
|
||||||
}
|
}
|
||||||
/// Handle Leaf Case for Dynamic Matrix type (slower)
|
/// Handle Leaf Case for Dynamic Matrix type (slower)
|
||||||
template<>
|
template<>
|
||||||
void handleLeafCase(
|
inline void handleLeafCase(
|
||||||
const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>& dTdA,
|
const Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>& dTdA,
|
||||||
JacobianMap& jacobians, Key key) {
|
JacobianMap& jacobians, Key key) {
|
||||||
jacobians(key) += dTdA;
|
jacobians(key) += dTdA;
|
||||||
|
@ -140,10 +145,10 @@ class ExecutionTrace {
|
||||||
Key key;
|
Key key;
|
||||||
CallRecord<Dim>* ptr;
|
CallRecord<Dim>* ptr;
|
||||||
} content;
|
} content;
|
||||||
public:
|
public:
|
||||||
/// Pointer always starts out as a Constant
|
/// Pointer always starts out as a Constant
|
||||||
ExecutionTrace() :
|
ExecutionTrace() :
|
||||||
kind(Constant) {
|
kind(Constant) {
|
||||||
}
|
}
|
||||||
/// Change pointer to a Leaf Record
|
/// Change pointer to a Leaf Record
|
||||||
void setLeaf(Key key) {
|
void setLeaf(Key key) {
|
||||||
|
@ -216,7 +221,7 @@ template<size_t ROWS, class A>
|
||||||
struct Select {
|
struct Select {
|
||||||
typedef Eigen::Matrix<double, ROWS, traits::dimension<A>::value> Jacobian;
|
typedef Eigen::Matrix<double, ROWS, traits::dimension<A>::value> Jacobian;
|
||||||
static void reverseAD(const ExecutionTrace<A>& trace, const Jacobian& dTdA,
|
static void reverseAD(const ExecutionTrace<A>& trace, const Jacobian& dTdA,
|
||||||
JacobianMap& jacobians) {
|
JacobianMap& jacobians) {
|
||||||
trace.reverseAD(dTdA, jacobians);
|
trace.reverseAD(dTdA, jacobians);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -226,7 +231,7 @@ template<class A>
|
||||||
struct Select<2, A> {
|
struct Select<2, A> {
|
||||||
typedef Eigen::Matrix<double, 2, traits::dimension<A>::value> Jacobian;
|
typedef Eigen::Matrix<double, 2, traits::dimension<A>::value> Jacobian;
|
||||||
static void reverseAD(const ExecutionTrace<A>& trace, const Jacobian& dTdA,
|
static void reverseAD(const ExecutionTrace<A>& trace, const Jacobian& dTdA,
|
||||||
JacobianMap& jacobians) {
|
JacobianMap& jacobians) {
|
||||||
trace.reverseAD2(dTdA, jacobians);
|
trace.reverseAD2(dTdA, jacobians);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -242,16 +247,16 @@ struct Select<2, A> {
|
||||||
template<class T>
|
template<class T>
|
||||||
class ExpressionNode {
|
class ExpressionNode {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
||||||
size_t traceSize_;
|
size_t traceSize_;
|
||||||
|
|
||||||
/// Constructor, traceSize is size of the execution trace of expression rooted here
|
/// Constructor, traceSize is size of the execution trace of expression rooted here
|
||||||
ExpressionNode(size_t traceSize = 0) :
|
ExpressionNode(size_t traceSize = 0) :
|
||||||
traceSize_(traceSize) {
|
traceSize_(traceSize) {
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~ExpressionNode() {
|
virtual ~ExpressionNode() {
|
||||||
|
@ -277,7 +282,7 @@ public:
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const = 0;
|
char* raw) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
|
@ -290,12 +295,12 @@ class ConstantExpression: public ExpressionNode<T> {
|
||||||
|
|
||||||
/// Constructor with a value, yielding a constant
|
/// Constructor with a value, yielding a constant
|
||||||
ConstantExpression(const T& value) :
|
ConstantExpression(const T& value) :
|
||||||
constant_(value) {
|
constant_(value) {
|
||||||
}
|
}
|
||||||
|
|
||||||
friend class Expression<T> ;
|
friend class Expression<T> ;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const {
|
virtual T value(const Values& values) const {
|
||||||
|
@ -304,7 +309,7 @@ public:
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
return constant_;
|
return constant_;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -320,13 +325,13 @@ class LeafExpression: public ExpressionNode<T> {
|
||||||
|
|
||||||
/// Constructor with a single key
|
/// Constructor with a single key
|
||||||
LeafExpression(Key key) :
|
LeafExpression(Key key) :
|
||||||
key_(key) {
|
key_(key) {
|
||||||
}
|
}
|
||||||
// todo: do we need a virtual destructor here too?
|
// todo: do we need a virtual destructor here too?
|
||||||
|
|
||||||
friend class Expression<value_type> ;
|
friend class Expression<value_type> ;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// Return keys that play in this expression
|
/// Return keys that play in this expression
|
||||||
virtual std::set<Key> keys() const {
|
virtual std::set<Key> keys() const {
|
||||||
|
@ -348,7 +353,7 @@ public:
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual const value_type& traceExecution(const Values& values, ExecutionTrace<value_type>& trace,
|
virtual const value_type& traceExecution(const Values& values, ExecutionTrace<value_type>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
trace.setLeaf(key_);
|
trace.setLeaf(key_);
|
||||||
return dynamic_cast<const value_type&>(values.at(key_));
|
return dynamic_cast<const value_type&>(values.at(key_));
|
||||||
}
|
}
|
||||||
|
@ -366,13 +371,13 @@ class LeafExpression<T, DefaultChart<T> >: public ExpressionNode<T> {
|
||||||
|
|
||||||
/// Constructor with a single key
|
/// Constructor with a single key
|
||||||
LeafExpression(Key key) :
|
LeafExpression(Key key) :
|
||||||
key_(key) {
|
key_(key) {
|
||||||
}
|
}
|
||||||
// todo: do we need a virtual destructor here too?
|
// todo: do we need a virtual destructor here too?
|
||||||
|
|
||||||
friend class Expression<T> ;
|
friend class Expression<T> ;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// Return keys that play in this expression
|
/// Return keys that play in this expression
|
||||||
virtual std::set<Key> keys() const {
|
virtual std::set<Key> keys() const {
|
||||||
|
@ -393,7 +398,7 @@ public:
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
trace.setLeaf(key_);
|
trace.setLeaf(key_);
|
||||||
return values.at<T>(key_);
|
return values.at<T>(key_);
|
||||||
}
|
}
|
||||||
|
@ -523,7 +528,7 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
|
||||||
virtual void startReverseAD(JacobianMap& jacobians) const {
|
virtual void startReverseAD(JacobianMap& jacobians) const {
|
||||||
Base::Record::startReverseAD(jacobians);
|
Base::Record::startReverseAD(jacobians);
|
||||||
Select<traits::dimension<T>::value, A>::reverseAD(This::trace, This::dTdA,
|
Select<traits::dimension<T>::value, A>::reverseAD(This::trace, This::dTdA,
|
||||||
jacobians);
|
jacobians);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
/// Given df/dT, multiply in dT/dA and continue reverse AD process
|
||||||
|
@ -535,7 +540,7 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
|
||||||
/// Version specialized to 2-dimensional output
|
/// Version specialized to 2-dimensional output
|
||||||
typedef Eigen::Matrix<double, 2, traits::dimension<T>::value> Jacobian2T;
|
typedef Eigen::Matrix<double, 2, traits::dimension<T>::value> Jacobian2T;
|
||||||
virtual void reverseAD2(const Jacobian2T& dFdT,
|
virtual void reverseAD2(const Jacobian2T& dFdT,
|
||||||
JacobianMap& jacobians) const {
|
JacobianMap& jacobians) const {
|
||||||
Base::Record::reverseAD2(dFdT, jacobians);
|
Base::Record::reverseAD2(dFdT, jacobians);
|
||||||
This::trace.reverseAD2(dFdT * This::dTdA, jacobians);
|
This::trace.reverseAD2(dFdT * This::dTdA, jacobians);
|
||||||
}
|
}
|
||||||
|
@ -549,7 +554,7 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
|
||||||
// Iff the expression is functional, write all Records in raw buffer
|
// Iff the expression is functional, write all Records in raw buffer
|
||||||
// Return value of type T is recorded in record->value
|
// Return value of type T is recorded in record->value
|
||||||
record->Record::This::value = This::expression->traceExecution(values,
|
record->Record::This::value = This::expression->traceExecution(values,
|
||||||
record->Record::This::trace, raw);
|
record->Record::This::trace, raw);
|
||||||
// raw is never modified by traceExecution, but if traceExecution has
|
// raw is never modified by traceExecution, but if traceExecution has
|
||||||
// written in the buffer, the next caller expects we advance the pointer
|
// written in the buffer, the next caller expects we advance the pointer
|
||||||
raw += This::expression->traceSize();
|
raw += This::expression->traceSize();
|
||||||
|
@ -623,26 +628,26 @@ struct FunctionalNode {
|
||||||
template<class T, class A1>
|
template<class T, class A1>
|
||||||
class UnaryExpression: public FunctionalNode<T, boost::mpl::vector<A1> >::type {
|
class UnaryExpression: public FunctionalNode<T, boost::mpl::vector<A1> >::type {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef boost::function<T(const A1&, typename OptionalJacobian<T, A1>::type)> Function;
|
typedef boost::function<T(const A1&, typename OptionalJacobian<T, A1>::type)> Function;
|
||||||
typedef typename FunctionalNode<T, boost::mpl::vector<A1> >::type Base;
|
typedef typename FunctionalNode<T, boost::mpl::vector<A1> >::type Base;
|
||||||
typedef typename Base::Record Record;
|
typedef typename Base::Record Record;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
Function function_;
|
Function function_;
|
||||||
|
|
||||||
/// Constructor with a unary function f, and input argument e
|
/// Constructor with a unary function f, and input argument e
|
||||||
UnaryExpression(Function f, const Expression<A1>& e1) :
|
UnaryExpression(Function f, const Expression<A1>& e1) :
|
||||||
function_(f) {
|
function_(f) {
|
||||||
this->template reset<A1, 1>(e1.root());
|
this->template reset<A1, 1>(e1.root());
|
||||||
ExpressionNode<T>::traceSize_ = sizeof(Record) + e1.traceSize();
|
ExpressionNode<T>::traceSize_ = sizeof(Record) + e1.traceSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
friend class Expression<T> ;
|
friend class Expression<T> ;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const {
|
virtual T value(const Values& values) const {
|
||||||
|
@ -651,13 +656,13 @@ public:
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
|
|
||||||
Record* record = Base::trace(values, raw);
|
Record* record = Base::trace(values, raw);
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
||||||
return function_(record->template value<A1, 1>(),
|
return function_(record->template value<A1, 1>(),
|
||||||
record->template jacobian<A1, 1>());
|
record->template jacobian<A1, 1>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -667,22 +672,22 @@ public:
|
||||||
template<class T, class A1, class A2>
|
template<class T, class A1, class A2>
|
||||||
class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> >::type {
|
class BinaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2> >::type {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef boost::function<
|
typedef boost::function<
|
||||||
T(const A1&, const A2&, typename OptionalJacobian<T, A1>::type,
|
T(const A1&, const A2&, typename OptionalJacobian<T, A1>::type,
|
||||||
typename OptionalJacobian<T, A2>::type)> Function;
|
typename OptionalJacobian<T, A2>::type)> Function;
|
||||||
typedef typename FunctionalNode<T, boost::mpl::vector<A1, A2> >::type Base;
|
typedef typename FunctionalNode<T, boost::mpl::vector<A1, A2> >::type Base;
|
||||||
typedef typename Base::Record Record;
|
typedef typename Base::Record Record;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
Function function_;
|
Function function_;
|
||||||
|
|
||||||
/// Constructor with a ternary function f, and three input arguments
|
/// Constructor with a ternary function f, and three input arguments
|
||||||
BinaryExpression(Function f, const Expression<A1>& e1,
|
BinaryExpression(Function f, const Expression<A1>& e1,
|
||||||
const Expression<A2>& e2) :
|
const Expression<A2>& e2) :
|
||||||
function_(f) {
|
function_(f) {
|
||||||
this->template reset<A1, 1>(e1.root());
|
this->template reset<A1, 1>(e1.root());
|
||||||
this->template reset<A2, 2>(e2.root());
|
this->template reset<A2, 2>(e2.root());
|
||||||
ExpressionNode<T>::traceSize_ = //
|
ExpressionNode<T>::traceSize_ = //
|
||||||
|
@ -692,26 +697,26 @@ private:
|
||||||
friend class Expression<T> ;
|
friend class Expression<T> ;
|
||||||
friend class ::ExpressionFactorBinaryTest;
|
friend class ::ExpressionFactorBinaryTest;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const {
|
virtual T value(const Values& values) const {
|
||||||
using boost::none;
|
using boost::none;
|
||||||
return function_(this->template expression<A1, 1>()->value(values),
|
return function_(this->template expression<A1, 1>()->value(values),
|
||||||
this->template expression<A2, 2>()->value(values),
|
this->template expression<A2, 2>()->value(values),
|
||||||
none, none);
|
none, none);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
|
|
||||||
Record* record = Base::trace(values, raw);
|
Record* record = Base::trace(values, raw);
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
||||||
return function_(record->template value<A1, 1>(),
|
return function_(record->template value<A1, 1>(),
|
||||||
record->template value<A2,2>(), record->template jacobian<A1, 1>(),
|
record->template value<A2,2>(), record->template jacobian<A1, 1>(),
|
||||||
record->template jacobian<A2, 2>());
|
record->template jacobian<A2, 2>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -721,22 +726,22 @@ public:
|
||||||
template<class T, class A1, class A2, class A3>
|
template<class T, class A1, class A2, class A3>
|
||||||
class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type {
|
class TernaryExpression: public FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef boost::function<
|
typedef boost::function<
|
||||||
T(const A1&, const A2&, const A3&, typename OptionalJacobian<T, A1>::type,
|
T(const A1&, const A2&, const A3&, typename OptionalJacobian<T, A1>::type,
|
||||||
typename OptionalJacobian<T, A2>::type, typename OptionalJacobian<T, A3>::type)> Function;
|
typename OptionalJacobian<T, A2>::type, typename OptionalJacobian<T, A3>::type)> Function;
|
||||||
typedef typename FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type Base;
|
typedef typename FunctionalNode<T, boost::mpl::vector<A1, A2, A3> >::type Base;
|
||||||
typedef typename Base::Record Record;
|
typedef typename Base::Record Record;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
Function function_;
|
Function function_;
|
||||||
|
|
||||||
/// Constructor with a ternary function f, and three input arguments
|
/// Constructor with a ternary function f, and three input arguments
|
||||||
TernaryExpression(Function f, const Expression<A1>& e1,
|
TernaryExpression(Function f, const Expression<A1>& e1,
|
||||||
const Expression<A2>& e2, const Expression<A3>& e3) :
|
const Expression<A2>& e2, const Expression<A3>& e3) :
|
||||||
function_(f) {
|
function_(f) {
|
||||||
this->template reset<A1, 1>(e1.root());
|
this->template reset<A1, 1>(e1.root());
|
||||||
this->template reset<A2, 2>(e2.root());
|
this->template reset<A2, 2>(e2.root());
|
||||||
this->template reset<A3, 3>(e3.root());
|
this->template reset<A3, 3>(e3.root());
|
||||||
|
@ -746,20 +751,20 @@ private:
|
||||||
|
|
||||||
friend class Expression<T> ;
|
friend class Expression<T> ;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// Return value
|
/// Return value
|
||||||
virtual T value(const Values& values) const {
|
virtual T value(const Values& values) const {
|
||||||
using boost::none;
|
using boost::none;
|
||||||
return function_(this->template expression<A1, 1>()->value(values),
|
return function_(this->template expression<A1, 1>()->value(values),
|
||||||
this->template expression<A2, 2>()->value(values),
|
this->template expression<A2, 2>()->value(values),
|
||||||
this->template expression<A3, 3>()->value(values),
|
this->template expression<A3, 3>()->value(values),
|
||||||
none, none, none);
|
none, none, none);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct an execution trace for reverse AD
|
/// Construct an execution trace for reverse AD
|
||||||
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
virtual T traceExecution(const Values& values, ExecutionTrace<T>& trace,
|
||||||
char* raw) const {
|
char* raw) const {
|
||||||
|
|
||||||
Record* record = Base::trace(values, raw);
|
Record* record = Base::trace(values, raw);
|
||||||
trace.setFunction(record);
|
trace.setFunction(record);
|
||||||
|
|
Loading…
Reference in New Issue