Now uses CallRecord.h

release/4.3a0
dellaert 2014-11-21 15:48:29 +01:00
parent 2983cf33a6
commit c238e5852c
1 changed files with 5 additions and 119 deletions

View File

@ -19,17 +19,16 @@
#pragma once #pragma once
#include <gtsam_unstable/nonlinear/Callrecord.h>
#include <gtsam/nonlinear/Values.h> #include <gtsam/nonlinear/Values.h>
#include <gtsam/base/Matrix.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/Manifold.h> #include <gtsam/base/Manifold.h>
#include <gtsam/base/VerticalBlockMatrix.h>
#include <boost/foreach.hpp> #include <boost/foreach.hpp>
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <boost/bind.hpp> #include <boost/bind.hpp>
// template meta-programming headers // template meta-programming headers, TODO not all needed?
#include <boost/mpl/vector.hpp> #include <boost/mpl/vector.hpp>
#include <boost/mpl/plus.hpp> #include <boost/mpl/plus.hpp>
#include <boost/mpl/front.hpp> #include <boost/mpl/front.hpp>
@ -40,8 +39,10 @@
#include <boost/mpl/transform.hpp> #include <boost/mpl/transform.hpp>
#include <boost/mpl/at.hpp> #include <boost/mpl/at.hpp>
namespace MPL = boost::mpl::placeholders; namespace MPL = boost::mpl::placeholders;
//
//#include <new> // for placement new
#include <map>
#include <new> // for placement new
class ExpressionFactorBinaryTest; class ExpressionFactorBinaryTest;
// Forward declare for testing // Forward declare for testing
@ -71,121 +72,6 @@ public:
} }
}; };
//-----------------------------------------------------------------------------
/**
* MaxVirtualStaticRows defines how many separate virtual reverseAD with specific
* static rows (1..MaxVirtualStaticRows) methods will be part of the CallRecord interface.
*/
const int MaxVirtualStaticRows = 4;
namespace internal {
/**
* ConvertToDynamicIf converts to a dense matrix with dynamic rows iff ConvertToDynamicRows (colums stay as they are) otherwise
* it just passes dense Eigen matrices through.
*/
template <bool ConvertToDynamicRows>
struct ConvertToDynamicRowsIf {
template <typename Derived>
static Eigen::Matrix<double, Eigen::Dynamic, Derived::ColsAtCompileTime> convert(const Eigen::MatrixBase<Derived> & x){
return x;
}
};
template <>
struct ConvertToDynamicRowsIf<false> {
template <int Rows, int Cols>
static const Eigen::Matrix<double, Rows, Cols> & convert(const Eigen::Matrix<double, Rows, Cols> & x){
return x;
}
};
/**
* Recursive definition of an interface having several purely
* virtual _reverseAD(const Eigen::Matrix<double, Rows, Cols> &, JacobianMap&)
* with Rows in 1..MaxSupportedStaticRows
*/
template<int MaxSupportedStaticRows, int Cols>
struct ReverseADInterface : public ReverseADInterface < MaxSupportedStaticRows - 1, Cols> {
protected:
using ReverseADInterface < MaxSupportedStaticRows - 1, Cols>::_reverseAD;
virtual void _reverseAD(const Eigen::Matrix<double, MaxSupportedStaticRows, Cols> & dFdT, JacobianMap& jacobians) const = 0;
};
template<int Cols>
struct ReverseADInterface<0, Cols> {
protected:
void _reverseAD(){} //dummy to allow the using directive in the template without failing for MaxSupportedStaticRows == 1.
};
}
/**
* The CallRecord class stores the Jacobians of applying a function
* with respect to each of its arguments. It also stores an executation trace
* (defined below) for each of its arguments.
*
* It is implemented in the function-style ExpressionNode's nested Record class below.
*/
template<int Cols>
struct CallRecord : private internal::ReverseADInterface<MaxVirtualStaticRows, Cols> {
inline void print(const std::string& indent) const {
_print(indent);
}
inline void startReverseAD(JacobianMap& jacobians) const {
_startReverseAD(jacobians);
}
template <int Rows>
inline void reverseAD(const Eigen::Matrix<double, Rows, Cols> & dFdT, JacobianMap& jacobians) const{
_reverseAD(internal::ConvertToDynamicRowsIf<(Rows > MaxVirtualStaticRows)>::convert(dFdT), jacobians);
}
virtual ~CallRecord() {
}
private:
using internal::ReverseADInterface<MaxVirtualStaticRows, Cols>::_reverseAD;
virtual void _print(const std::string& indent) const = 0;
virtual void _startReverseAD(JacobianMap& jacobians) const = 0;
virtual void _reverseAD(const Eigen::Matrix<double, Eigen::Dynamic, Cols> & dFdT, JacobianMap& jacobians) const = 0;
virtual void _reverseAD(const Matrix & dFdT, JacobianMap& jacobians) const = 0;
};
namespace internal {
/**
* ReverseADImplementor is a utility class used by CallRecordImplementor to implementing the recursive CallRecord interface.
*/
template <typename Derived, int MaxSupportedStaticRows, int Cols>
struct ReverseADImplementor : ReverseADImplementor<Derived, MaxSupportedStaticRows - 1, Cols> {
protected:
virtual void _reverseAD(const Eigen::Matrix<double, MaxSupportedStaticRows, Cols> & dFdT, JacobianMap& jacobians) const {
static_cast<const Derived *>(this)->reverseAD(dFdT, jacobians);
}
};
template<typename Derived, int Cols>
struct ReverseADImplementor<Derived, 0, Cols> : CallRecord<Cols> {
};
/**
* The CallRecordImplementor implements the CallRecord interface for a Derived class by
* delegating to its corresponding (templated) non-virtual methods.
*/
template<typename Derived, int Cols>
struct CallRecordImplementor : public ReverseADImplementor<Derived, MaxVirtualStaticRows, Cols> {
private:
const Derived & derived() const {
return static_cast<const Derived&>(*this);
}
virtual void _print(const std::string& indent) const {
derived().print(indent);
}
virtual void _startReverseAD(JacobianMap& jacobians) const {
derived().startReverseAD(jacobians);
}
virtual void _reverseAD(const Eigen::Matrix<double, Eigen::Dynamic, Cols> & dFdT, JacobianMap& jacobians) const {
derived().reverseAD(dFdT, jacobians);
}
virtual void _reverseAD(const Matrix & dFdT, JacobianMap& jacobians) const {
derived().reverseAD(dFdT, jacobians);
}
};
}
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
/// Handle Leaf Case: reverseAD ends here, by writing a matrix into Jacobians /// Handle Leaf Case: reverseAD ends here, by writing a matrix into Jacobians
template<int ROWS, int COLS> template<int ROWS, int COLS>