diff --git a/gtsam_unstable/nonlinear/CallRecord.h b/gtsam_unstable/nonlinear/CallRecord.h index f1ac0b044..5a1fdadc4 100644 --- a/gtsam_unstable/nonlinear/CallRecord.h +++ b/gtsam_unstable/nonlinear/CallRecord.h @@ -36,7 +36,7 @@ class JacobianMap; * MaxVirtualStaticRows defines how many separate virtual reverseAD with specific * static rows (1..MaxVirtualStaticRows) methods will be part of the CallRecord interface. */ -const int MaxVirtualStaticRows = 4; +#define MaxVirtualStaticRows 4 namespace internal { @@ -69,65 +69,6 @@ struct ConvertToVirtualFunctionSupportedMatrixType { } }; -/** - * Recursive definition of an interface having several purely - * virtual _reverseAD(const Eigen::Matrix &, JacobianMap&) - * with Rows in 1..MaxSupportedStaticRows - */ -template -struct ReverseADInterface: ReverseADInterface { - using ReverseADInterface::_reverseAD; - virtual void _reverseAD( - const Eigen::Matrix & dFdT, - JacobianMap& jacobians) const = 0; -}; - -template -struct ReverseADInterface<0, Cols> { - virtual void _reverseAD( - const Eigen::Matrix & dFdT, - JacobianMap& jacobians) const = 0; - virtual void _reverseAD(const Matrix & dFdT, - JacobianMap& jacobians) const = 0; -}; - -/** - * ReverseADImplementor is a utility class used by CallRecordImplementor to - * implementing the recursive ReverseADInterface interface. - */ -template -struct ReverseADImplementor: ReverseADImplementor { -private: - using ReverseADImplementor::_reverseAD; - virtual void _reverseAD( - const Eigen::Matrix & dFdT, - JacobianMap& jacobians) const { - static_cast(this)->reverseAD(dFdT, jacobians); - } - friend struct internal::ReverseADImplementor; -}; - -template -struct ReverseADImplementor : virtual internal::ReverseADInterface< - MaxVirtualStaticRows, Cols> { -private: - using internal::ReverseADInterface::_reverseAD; - const Derived & derived() const { - return static_cast(*this); - } - virtual void _reverseAD( - const Eigen::Matrix & dFdT, - JacobianMap& jacobians) const { - derived().reverseAD(dFdT, jacobians); - } - virtual void _reverseAD(const Matrix & dFdT, JacobianMap& jacobians) const { - derived().reverseAD(dFdT, jacobians); - } - friend struct internal::ReverseADImplementor; -}; - } // namespace internal /** @@ -138,9 +79,7 @@ private: * It is implemented in the function-style ExpressionNode's nested Record class below. */ template -struct CallRecord: virtual private internal::ReverseADInterface< - MaxVirtualStaticRows, Cols> { - +struct CallRecord { inline void print(const std::string& indent) const { _print(indent); } @@ -153,8 +92,11 @@ struct CallRecord: virtual private internal::ReverseADInterface< inline void reverseAD(const Eigen::MatrixBase & dFdT, JacobianMap& jacobians) const { _reverseAD( - internal::ConvertToVirtualFunctionSupportedMatrixType<(Derived::RowsAtCompileTime > MaxVirtualStaticRows)>::convert( - dFdT), jacobians); + internal::ConvertToVirtualFunctionSupportedMatrixType< + (Derived::RowsAtCompileTime > MaxVirtualStaticRows) + >::convert(dFdT), + jacobians + ); } inline void reverseAD(const Matrix & dFdT, JacobianMap& jacobians) const { @@ -167,7 +109,36 @@ struct CallRecord: virtual private internal::ReverseADInterface< private: virtual void _print(const std::string& indent) const = 0; virtual void _startReverseAD(JacobianMap& jacobians) const = 0; - using internal::ReverseADInterface::_reverseAD; + + virtual void _reverseAD(const Matrix & dFdT, JacobianMap& jacobians) const = 0; + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const = 0; +#if MaxVirtualStaticRows >= 1 + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const = 0; +#endif +#if MaxVirtualStaticRows >= 2 + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const = 0; +#endif +#if MaxVirtualStaticRows >= 3 + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const = 0; +#endif +#if MaxVirtualStaticRows >= 4 + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const = 0; +#endif +#if MaxVirtualStaticRows >= 5 + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const = 0; +#endif }; namespace internal { @@ -176,8 +147,7 @@ namespace internal { * delegating to its corresponding (templated) non-virtual methods. */ template -struct CallRecordImplementor: public CallRecord, - private ReverseADImplementor { +struct CallRecordImplementor: public CallRecord { private: const Derived & derived() const { return static_cast(*this); @@ -188,7 +158,50 @@ private: virtual void _startReverseAD(JacobianMap& jacobians) const { derived().startReverseAD(jacobians); } - template friend struct ReverseADImplementor; + + virtual void _reverseAD(const Matrix & dFdT, JacobianMap& jacobians) const { + derived().reverseAD(dFdT, jacobians); + } + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const { + derived().reverseAD(dFdT, jacobians); + } +#if MaxVirtualStaticRows >= 1 + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const { + derived().reverseAD(dFdT, jacobians); + } +#endif +#if MaxVirtualStaticRows >= 2 + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const { + derived().reverseAD(dFdT, jacobians); + } +#endif +#if MaxVirtualStaticRows >= 3 + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const { + derived().reverseAD(dFdT, jacobians); + } +#endif +#if MaxVirtualStaticRows >= 4 + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const { + derived().reverseAD(dFdT, jacobians); + } +#endif +#if MaxVirtualStaticRows >= 5 + virtual void _reverseAD( + const Eigen::Matrix & dFdT, + JacobianMap& jacobians) const { + derived().reverseAD(dFdT, jacobians); + } +#endif }; } // namespace internal diff --git a/gtsam_unstable/nonlinear/tests/testCallRecord.cpp b/gtsam_unstable/nonlinear/tests/testCallRecord.cpp index a4561b349..f0569151b 100644 --- a/gtsam_unstable/nonlinear/tests/testCallRecord.cpp +++ b/gtsam_unstable/nonlinear/tests/testCallRecord.cpp @@ -90,7 +90,7 @@ struct Record: public internal::CallRecordImplementor { } template - friend struct internal::ReverseADImplementor; + friend struct internal::CallRecordImplementor; }; JacobianMap & NJM= *static_cast(NULL);