Now we can apply ExecutionTrace and Expression as meta-functions
parent
c11d7885e1
commit
1c1695353e
|
@ -201,7 +201,7 @@ template<class T>
|
||||||
class ExecutionTrace {
|
class ExecutionTrace {
|
||||||
enum {
|
enum {
|
||||||
Constant, Leaf, Function
|
Constant, Leaf, Function
|
||||||
} type;
|
} kind;
|
||||||
union {
|
union {
|
||||||
Key key;
|
Key key;
|
||||||
CallRecord<T::dimension>* ptr;
|
CallRecord<T::dimension>* ptr;
|
||||||
|
@ -209,25 +209,25 @@ class ExecutionTrace {
|
||||||
public:
|
public:
|
||||||
/// Pointer always starts out as a Constant
|
/// Pointer always starts out as a Constant
|
||||||
ExecutionTrace() :
|
ExecutionTrace() :
|
||||||
type(Constant) {
|
kind(Constant) {
|
||||||
}
|
}
|
||||||
/// Change pointer to a Leaf Record
|
/// Change pointer to a Leaf Record
|
||||||
void setLeaf(Key key) {
|
void setLeaf(Key key) {
|
||||||
type = Leaf;
|
kind = Leaf;
|
||||||
content.key = key;
|
content.key = key;
|
||||||
}
|
}
|
||||||
/// Take ownership of pointer to a Function Record
|
/// Take ownership of pointer to a Function Record
|
||||||
void setFunction(CallRecord<T::dimension>* record) {
|
void setFunction(CallRecord<T::dimension>* record) {
|
||||||
type = Function;
|
kind = Function;
|
||||||
content.ptr = record;
|
content.ptr = record;
|
||||||
}
|
}
|
||||||
/// Print
|
/// Print
|
||||||
void print(const std::string& indent = "") const {
|
void print(const std::string& indent = "") const {
|
||||||
if (type == Constant)
|
if (kind == Constant)
|
||||||
std::cout << indent << "Constant" << std::endl;
|
std::cout << indent << "Constant" << std::endl;
|
||||||
else if (type == Leaf)
|
else if (kind == Leaf)
|
||||||
std::cout << indent << "Leaf, key = " << content.key << std::endl;
|
std::cout << indent << "Leaf, key = " << content.key << std::endl;
|
||||||
else if (type == Function) {
|
else if (kind == Function) {
|
||||||
std::cout << indent << "Function" << std::endl;
|
std::cout << indent << "Function" << std::endl;
|
||||||
content.ptr->print(indent + " ");
|
content.ptr->print(indent + " ");
|
||||||
}
|
}
|
||||||
|
@ -235,7 +235,7 @@ public:
|
||||||
/// Return record pointer, quite unsafe, used only for testing
|
/// Return record pointer, quite unsafe, used only for testing
|
||||||
template<class Record>
|
template<class Record>
|
||||||
boost::optional<Record*> record() {
|
boost::optional<Record*> record() {
|
||||||
if (type != Function)
|
if (kind != Function)
|
||||||
return boost::none;
|
return boost::none;
|
||||||
else {
|
else {
|
||||||
Record* p = dynamic_cast<Record*>(content.ptr);
|
Record* p = dynamic_cast<Record*>(content.ptr);
|
||||||
|
@ -245,38 +245,41 @@ public:
|
||||||
// *** This is the main entry point for reverseAD, called from Expression::augmented ***
|
// *** This is the main entry point for reverseAD, called from Expression::augmented ***
|
||||||
// Called only once, either inserts identity into Jacobians (Leaf) or starts AD (Function)
|
// Called only once, either inserts identity into Jacobians (Leaf) or starts AD (Function)
|
||||||
void startReverseAD(JacobianMap& jacobians) const {
|
void startReverseAD(JacobianMap& jacobians) const {
|
||||||
if (type == Leaf) {
|
if (kind == Leaf) {
|
||||||
// This branch will only be called on trivial Leaf expressions, i.e. Priors
|
// This branch will only be called on trivial Leaf expressions, i.e. Priors
|
||||||
size_t n = T::Dim();
|
size_t n = T::Dim();
|
||||||
jacobians[content.key] = Eigen::MatrixXd::Identity(n, n);
|
jacobians[content.key] = Eigen::MatrixXd::Identity(n, n);
|
||||||
} else if (type == Function)
|
} else if (kind == Function)
|
||||||
// This is the more typical entry point, starting the AD pipeline
|
// This is the more typical entry point, starting the AD pipeline
|
||||||
// It is inside the startReverseAD that the correctly dimensioned pipeline is chosen.
|
// It is inside the startReverseAD that the correctly dimensioned pipeline is chosen.
|
||||||
content.ptr->startReverseAD(jacobians);
|
content.ptr->startReverseAD(jacobians);
|
||||||
}
|
}
|
||||||
// Either add to Jacobians (Leaf) or propagate (Function)
|
// Either add to Jacobians (Leaf) or propagate (Function)
|
||||||
void reverseAD(const Matrix& dTdA, JacobianMap& jacobians) const {
|
void reverseAD(const Matrix& dTdA, JacobianMap& jacobians) const {
|
||||||
if (type == Leaf) {
|
if (kind == Leaf) {
|
||||||
JacobianMap::iterator it = jacobians.find(content.key);
|
JacobianMap::iterator it = jacobians.find(content.key);
|
||||||
if (it != jacobians.end())
|
if (it != jacobians.end())
|
||||||
it->second += dTdA;
|
it->second += dTdA;
|
||||||
else
|
else
|
||||||
jacobians[content.key] = dTdA;
|
jacobians[content.key] = dTdA;
|
||||||
} else if (type == Function)
|
} else if (kind == Function)
|
||||||
content.ptr->reverseAD(dTdA, jacobians);
|
content.ptr->reverseAD(dTdA, jacobians);
|
||||||
}
|
}
|
||||||
// Either add to Jacobians (Leaf) or propagate (Function)
|
// Either add to Jacobians (Leaf) or propagate (Function)
|
||||||
typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T;
|
typedef Eigen::Matrix<double, 2, T::dimension> Jacobian2T;
|
||||||
void reverseAD2(const Jacobian2T& dTdA, JacobianMap& jacobians) const {
|
void reverseAD2(const Jacobian2T& dTdA, JacobianMap& jacobians) const {
|
||||||
if (type == Leaf) {
|
if (kind == Leaf) {
|
||||||
JacobianMap::iterator it = jacobians.find(content.key);
|
JacobianMap::iterator it = jacobians.find(content.key);
|
||||||
if (it != jacobians.end())
|
if (it != jacobians.end())
|
||||||
it->second += dTdA;
|
it->second += dTdA;
|
||||||
else
|
else
|
||||||
jacobians[content.key] = dTdA;
|
jacobians[content.key] = dTdA;
|
||||||
} else if (type == Function)
|
} else if (kind == Function)
|
||||||
content.ptr->reverseAD2(dTdA, jacobians);
|
content.ptr->reverseAD2(dTdA, jacobians);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Define type so we can apply it as a meta-function
|
||||||
|
typedef ExecutionTrace<T> type;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Primary template calls the generic Matrix reverseAD pipeline
|
/// Primary template calls the generic Matrix reverseAD pipeline
|
||||||
|
|
|
@ -147,6 +147,8 @@ public:
|
||||||
return root_;
|
return root_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Define type so we can apply it as a meta-function
|
||||||
|
typedef Expression<T> type;
|
||||||
};
|
};
|
||||||
|
|
||||||
// http://stackoverflow.com/questions/16260445/boost-bind-to-operator
|
// http://stackoverflow.com/questions/16260445/boost-bind-to-operator
|
||||||
|
|
|
@ -429,13 +429,35 @@ TEST(ExpressionFactor, composeTernary) {
|
||||||
namespace mpl = boost::mpl;
|
namespace mpl = boost::mpl;
|
||||||
|
|
||||||
#include <boost/mpl/assert.hpp>
|
#include <boost/mpl/assert.hpp>
|
||||||
|
#include <boost/mpl/transform.hpp>
|
||||||
|
#include <boost/mpl/equal.hpp>
|
||||||
template<class T> struct Incomplete;
|
template<class T> struct Incomplete;
|
||||||
|
|
||||||
typedef mpl::vector<Pose3, Point3, Cal3_S2> MyTypes;
|
// Check generation of FunctionalNode
|
||||||
|
typedef mpl::vector<Pose3, Point3> MyTypes;
|
||||||
typedef FunctionalNode<Point2, MyTypes>::type Generated;
|
typedef FunctionalNode<Point2, MyTypes>::type Generated;
|
||||||
//Incomplete<Generated> incomplete;
|
//Incomplete<Generated> incomplete;
|
||||||
BOOST_MPL_ASSERT((boost::is_same< Matrix2, Generated::Record::Jacobian2T >));
|
BOOST_MPL_ASSERT((boost::is_same< Matrix2, Generated::Record::Jacobian2T >));
|
||||||
|
|
||||||
|
// Try generating vectors of ExecutionTrace
|
||||||
|
typedef mpl::vector<ExecutionTrace<Pose3>, ExecutionTrace<Point3> > ExpectedTraces;
|
||||||
|
|
||||||
|
typedef mpl::transform<MyTypes,ExecutionTrace<MPL::_1> >::type MyTraces;
|
||||||
|
BOOST_MPL_ASSERT((boost::mpl::equal< ExpectedTraces, MyTraces >));
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
struct MakeTrace {
|
||||||
|
typedef ExecutionTrace<T> type;
|
||||||
|
};
|
||||||
|
typedef mpl::transform<MyTypes,MakeTrace<MPL::_1> >::type MyTraces1;
|
||||||
|
BOOST_MPL_ASSERT((boost::mpl::equal< ExpectedTraces, MyTraces1 >));
|
||||||
|
|
||||||
|
// Try generating vectors of Expression types
|
||||||
|
typedef mpl::vector<Expression<Pose3>, Expression<Point3> > ExpectedExpressions;
|
||||||
|
|
||||||
|
typedef mpl::transform<MyTypes,Expression<MPL::_1> >::type Expressions;
|
||||||
|
BOOST_MPL_ASSERT((boost::mpl::equal< ExpectedExpressions, Expressions >));
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue