diff --git a/gtsam_unstable/nonlinear/tests/testExpression.cpp b/gtsam_unstable/nonlinear/tests/testExpression.cpp index ed17d11f8..e04f25ed2 100644 --- a/gtsam_unstable/nonlinear/tests/testExpression.cpp +++ b/gtsam_unstable/nonlinear/tests/testExpression.cpp @@ -493,25 +493,59 @@ TEST(Expression, AutoDiff2) { /* ************************************************************************* */ // Adapt ceres-style autodiff -template -struct AutoDiff: Expression { - typedef boost::function Function; - AutoDiff(Function f, const Expression& e1, const Expression& e2) : - Expression(3) { +template +struct AutoDiff { + static const int N = dimension::value; + static const int M1 = dimension::value; + static const int M2 = dimension::value; + + typedef Eigen::Matrix JacobianTA1; + typedef Eigen::Matrix JacobianTA2; + + Point2 operator()(const A1& a1, const A2& a2, + boost::optional H1, boost::optional H2) { + + // Instantiate function + F f; + + // Make arguments + Vector9 P; // zero rotation, (0,5,0) translation, focal length 1 + P << 0, 0, 0, 0, 5, 0, 1, 0, 0; + Vector3 X(10, 0, -5); // negative Z-axis convention of Snavely! + + bool success; + Vector2 result; + + if (H1 || H2) { + + // Get derivatives with AutoDiff + double *parameters[] = { P.data(), X.data() }; + double *jacobians[] = { H1->data(), H2->data() }; + success = ceres::internal::AutoDiff::Differentiate(f, + parameters, 2, result.data(), jacobians); + + } else { + // Apply the mapping, to get result + success = f(P.data(), X.data(), result.data()); + } + return Point2(); } }; -TEST(Expression, SnavelyKeys) { +TEST(Expression, Snavely) { // The DefaultChart of Camera below is laid out like Snavely's 9-dim vector typedef PinholeCamera Camera; + Expression P(1); Expression X(2); - Expression expression = // - AutoDiff(SnavelyProjection(), P, X); +// AutoDiff f; + Expression expression( + AutoDiff(), P, X); set expected = list_of(1)(2); EXPECT(expected == expression.keys()); } + /* ************************************************************************* */ int main() { TestResult tr;