diff --git a/gtsam_unstable/base/tests/testBAD.cpp b/gtsam_unstable/base/tests/testBAD.cpp index a6e9469e2..558510970 100644 --- a/gtsam_unstable/base/tests/testBAD.cpp +++ b/gtsam_unstable/base/tests/testBAD.cpp @@ -25,6 +25,7 @@ #include #include +#include #include @@ -258,21 +259,24 @@ public: }; +/** + * Expression class that supports automatic differentiation + */ template class Expression { public: - // Initialize a constant expression + // Construct a constant expression Expression(const T& value) : root_(new ConstantExpression(value)) { } - // Initialize a leaf expression + // Construct a leaf expression Expression(const Key& key) : root_(new LeafExpression(key)) { } - /// Initialize a unary expression + /// Construct a unary expression template Expression(typename UnaryExpression::function f, const Expression& expression) { @@ -280,7 +284,7 @@ public: root_.reset(new UnaryExpression(f, expression)); } - /// Initialize a binary expression + /// Construct a binary expression template Expression(typename BinaryExpression::function f, const Expression& expression1, const Expression& expression2) { @@ -288,9 +292,30 @@ public: root_.reset(new BinaryExpression(f, expression1, expression2)); } + // http://stackoverflow.com/questions/16260445/boost-bind-to-operator + template + struct apply_product { + typedef R result_type; + template + R operator()(E1 const& x, E2 const& y) const { + return x * y; + } + }; + + /// Construct a product expression, assumes E1::operator*() exists + template + friend Expression operator*(const Expression& expression1, const Expression& expression2) { + using namespace boost; + boost::bind(apply_product,_1,_2); + return Expression(boost::bind(apply_product,_1,_2),expression1, expression2); + } + + /// Return keys that play in this expression std::set keys() const { return root_->keys(); } + + /// Return value and optional derivatives T value(const Values& values, boost::optional&> jacobians = boost::none) const { return root_->value(values, jacobians); @@ -441,6 +466,14 @@ TEST(BAD, test) { } +/* ************************************************************************* */ + +TEST(BAD, rotate) { + Expression R(1); + Expression p(2); + Expression q = R * p; +} + /* ************************************************************************* */ int main() { TestResult tr;