swtichted to Eigen::MatrixBase<Derived> as far as possible for dTdA to gain some performance and not to forget some types at some level :).

release/4.3a0
HannesSommer 2014-11-24 10:48:15 +01:00
parent 3bf92d1a47
commit db6c9ff378
2 changed files with 42 additions and 18 deletions

View File

@ -46,7 +46,7 @@ namespace internal {
* it just passes dense Eigen matrices through.
*/
template<bool ConvertToDynamicRows>
struct ConvertToDynamicRowsIf {
struct ConvertToVirtualFunctionSupportedMatrixType {
template<typename Derived>
static Eigen::Matrix<double, Eigen::Dynamic, Derived::ColsAtCompileTime> convert(
const Eigen::MatrixBase<Derived> & x) {
@ -55,7 +55,13 @@ struct ConvertToDynamicRowsIf {
};
template<>
struct ConvertToDynamicRowsIf<false> {
struct ConvertToVirtualFunctionSupportedMatrixType<false> {
template<typename Derived>
static const Eigen::Matrix<double, Derived::RowsAtCompileTime, Derived::ColsAtCompileTime> convert(
const Eigen::MatrixBase<Derived> & x) {
return x;
}
// special treatment of matrices that don't need conversion
template<int Rows, int Cols>
static const Eigen::Matrix<double, Rows, Cols> & convert(
const Eigen::Matrix<double, Rows, Cols> & x) {
@ -143,11 +149,11 @@ struct CallRecord: virtual private internal::ReverseADInterface<
_startReverseAD(jacobians);
}
template<int Rows>
inline void reverseAD(const Eigen::Matrix<double, Rows, Cols> & dFdT,
template<typename Derived>
inline void reverseAD(const Eigen::MatrixBase<Derived> & dFdT,
JacobianMap& jacobians) const {
_reverseAD(
internal::ConvertToDynamicRowsIf<(Rows > MaxVirtualStaticRows)>::convert(
internal::ConvertToVirtualFunctionSupportedMatrixType<(Derived::RowsAtCompileTime > MaxVirtualStaticRows)>::convert(
dFdT), jacobians);
}

View File

@ -64,18 +64,36 @@ public:
};
//-----------------------------------------------------------------------------
/// Handle Leaf Case: reverseAD ends here, by writing a matrix into Jacobians
template<int ROWS, int COLS>
void handleLeafCase(const Eigen::Matrix<double, ROWS, COLS>& dTdA,
JacobianMap& jacobians, Key key) {
jacobians(key).block<ROWS, COLS>(0, 0) += dTdA; // block makes HUGE difference
namespace internal {
template <bool UseBlock, typename Derived>
struct UseBlockIf {
static void addToJacobian(const Eigen::MatrixBase<Derived>& dTdA,
JacobianMap& jacobians, Key key){
// block makes HUGE difference
jacobians(key).block<Derived::RowsAtCompileTime, Derived::ColsAtCompileTime>(0, 0) += dTdA;
};
};
/// Handle Leaf Case for Dynamic Matrix type (slower)
template <typename Derived>
struct UseBlockIf<false, Derived> {
static void addToJacobian(const Eigen::MatrixBase<Derived>& dTdA,
JacobianMap& jacobians, Key key) {
jacobians(key) += dTdA;
}
};
}
/// Handle Leaf Case for Dynamic ROWS Matrix type (slower)
template<int COLS>
inline void handleLeafCase(
const Eigen::Matrix<double, Eigen::Dynamic, COLS>& dTdA,
/// Handle Leaf Case: reverseAD ends here, by writing a matrix into Jacobians
template<typename Derived>
void handleLeafCase(const Eigen::MatrixBase<Derived>& dTdA,
JacobianMap& jacobians, Key key) {
jacobians(key) += dTdA;
internal::UseBlockIf<
Derived::RowsAtCompileTime != Eigen::Dynamic &&
Derived::ColsAtCompileTime != Eigen::Dynamic,
Derived>
::addToJacobian(dTdA, jacobians, key);
}
//-----------------------------------------------------------------------------
@ -166,9 +184,9 @@ public:
void reverseAD(const Eigen::MatrixBase<DerivedMatrix> & dTdA,
JacobianMap& jacobians) const {
if (kind == Leaf)
handleLeafCase(dTdA.eval(), jacobians, content.key);
handleLeafCase(dTdA, jacobians, content.key);
else if (kind == Function)
content.ptr->reverseAD(dTdA.eval(), jacobians);
content.ptr->reverseAD(dTdA, jacobians);
}
/// Define type so we can apply it as a meta-function
@ -507,7 +525,7 @@ struct GenerateFunctionalNode: Argument<T, A, Base::N + 1>, Base {
void reverseAD(const Eigen::Matrix<double, Rows, Cols> & dFdT,
JacobianMap& jacobians) const {
Base::Record::reverseAD(dFdT, jacobians);
This::trace.reverseAD((dFdT * This::dTdA).eval(), jacobians);
This::trace.reverseAD(dFdT * This::dTdA, jacobians);
}
};