From 30ffcdd1371985b5415e7b22918c65ecbc42789e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 6 Jul 2020 21:48:51 -0400 Subject: [PATCH] Simplified FunctorizedFactor By adding the helper function MakeFunctorizedFactor, we now only need to provide the argument type in the template parameter list. This considerably simplifies the factor declaration, while removing the need for argument type and return type in the functor definition. Also added tests for std::function and lambda functions. --- gtsam/nonlinear/FunctorizedFactor.h | 38 ++++++++++----- .../nonlinear/tests/testFunctorizedFactor.cpp | 47 ++++++++++--------- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/gtsam/nonlinear/FunctorizedFactor.h b/gtsam/nonlinear/FunctorizedFactor.h index c88579587..a83198967 100644 --- a/gtsam/nonlinear/FunctorizedFactor.h +++ b/gtsam/nonlinear/FunctorizedFactor.h @@ -25,8 +25,8 @@ namespace gtsam { /** - * Factor which evaluates functor and uses the result to compute - * error on provided measurement. + * Factor which evaluates provided unary functor and uses the result to compute + * error with respect to the provided measurement. * * Template parameters are * @param R: The return type of the functor after evaluation. @@ -40,13 +40,12 @@ namespace gtsam { * class MultiplyFunctor { * double m_; ///< simple multiplier * public: - * using argument_type = Matrix; - * using return_type = Matrix; * MultiplyFunctor(double m) : m_(m) {} * Matrix operator()(const Matrix &X, * OptionalJacobian<-1, -1> H = boost::none) const { - * if (H) *H = m_ * Matrix::Identity(X.rows()*X.cols(), - * X.rows()*X.cols()); return m_ * X; + * if (H) + * *H = m_ * Matrix::Identity(X.rows()*X.cols(), X.rows()*X.cols()); + * return m_ * X; * } * }; * @@ -72,7 +71,7 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1 { /** Construct with given x and the parameters of the basis * * @param key: Factor key - * @param z: Measurement object of type R + * @param z: Measurement object of same type as that returned by functor * @param model: Noise model * @param func: The instance of the functor object */ @@ -85,7 +84,7 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1 { /// @return a deep copy of this factor virtual NonlinearFactor::shared_ptr clone() const { return boost::static_pointer_cast( - NonlinearFactor::shared_ptr(new FunctorizedFactor(*this))); + NonlinearFactor::shared_ptr(new FunctorizedFactor(*this))); } Vector evaluateError(const T ¶ms, @@ -108,8 +107,8 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1 { } virtual bool equals(const NonlinearFactor &other, double tol = 1e-9) const { - const FunctorizedFactor *e = - dynamic_cast *>(&other); + const FunctorizedFactor *e = + dynamic_cast *>(&other); const bool base = Base::equals(*e, tol); return e && Base::equals(other, tol) && traits::Equals(this->measured_, e->measured_, tol); @@ -129,8 +128,21 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1 { }; /// traits -template -struct traits> - : public Testable> {}; +template +struct traits> + : public Testable> {}; + +/** + * Helper function to create a functorized factor. + * + * Uses function template deduction to identify return type and functor type, so + * template list only needs the functor argument type. + */ +template +FunctorizedFactor MakeFunctorizedFactor(Key key, const R &z, + const SharedNoiseModel &model, + const FUNC func) { + return FunctorizedFactor(key, z, model, func); +} } // namespace gtsam diff --git a/gtsam/nonlinear/tests/testFunctorizedFactor.cpp b/gtsam/nonlinear/tests/testFunctorizedFactor.cpp index 9ff6b8e24..12dd6b91c 100644 --- a/gtsam/nonlinear/tests/testFunctorizedFactor.cpp +++ b/gtsam/nonlinear/tests/testFunctorizedFactor.cpp @@ -34,9 +34,6 @@ class MultiplyFunctor { double m_; ///< simple multiplier public: - using argument_type = Matrix; - using return_type = Matrix; - MultiplyFunctor(double m) : m_(m) {} Matrix operator()(const Matrix &X, @@ -47,13 +44,13 @@ class MultiplyFunctor { }; /* ************************************************************************* */ +// Test identity operation for FunctorizedFactor. TEST(FunctorizedFactor, Identity) { Matrix X = Matrix::Identity(3, 3), measurement = Matrix::Identity(3, 3); double multiplier = 1.0; - - FunctorizedFactor factor(key, measurement, model, - MultiplyFunctor(multiplier)); + auto functor = MultiplyFunctor(multiplier); + auto factor = MakeFunctorizedFactor(key, measurement, model, functor); Vector error = factor.evaluateError(X); @@ -61,41 +58,45 @@ TEST(FunctorizedFactor, Identity) { } /* ************************************************************************* */ +// Test FunctorizedFactor with multiplier value of 2. TEST(FunctorizedFactor, Multiply2) { double multiplier = 2.0; Matrix X = Matrix::Identity(3, 3); Matrix measurement = multiplier * Matrix::Identity(3, 3); - FunctorizedFactor factor(key, measurement, model, - MultiplyFunctor(multiplier)); + auto factor = MakeFunctorizedFactor(key, measurement, model, + MultiplyFunctor(multiplier)); Vector error = factor.evaluateError(X); EXPECT(assert_equal(Vector::Zero(9), error, 1e-9)); } +/* ************************************************************************* */ +// Test equality function for FunctorizedFactor. TEST(FunctorizedFactor, Equality) { Matrix measurement = Matrix::Identity(2, 2); double multiplier = 2.0; - FunctorizedFactor factor1(key, measurement, model, - MultiplyFunctor(multiplier)); - FunctorizedFactor factor2(key, measurement, model, - MultiplyFunctor(multiplier)); + auto factor1 = MakeFunctorizedFactor(key, measurement, model, + MultiplyFunctor(multiplier)); + auto factor2 = MakeFunctorizedFactor(key, measurement, model, + MultiplyFunctor(multiplier)); EXPECT(factor1.equals(factor2)); } -//****************************************************************************** +/* *************************************************************************** */ +// Test Jacobians of FunctorizedFactor. TEST(FunctorizedFactor, Jacobians) { Matrix X = Matrix::Identity(3, 3); Matrix actualH; double multiplier = 2.0; - FunctorizedFactor factor(key, X, model, - MultiplyFunctor(multiplier)); + auto factor = + MakeFunctorizedFactor(key, X, model, MultiplyFunctor(multiplier)); Values values; values.insert(key, X); @@ -105,13 +106,14 @@ TEST(FunctorizedFactor, Jacobians) { } /* ************************************************************************* */ +// Test print result of FunctorizedFactor. TEST(FunctorizedFactor, Print) { Matrix X = Matrix::Identity(2, 2); double multiplier = 2.0; - FunctorizedFactor factor(key, X, model, - MultiplyFunctor(multiplier)); + auto factor = + MakeFunctorizedFactor(key, X, model, MultiplyFunctor(multiplier)); // redirect output to buffer so we can compare stringstream buffer; @@ -137,7 +139,7 @@ TEST(FunctorizedFactor, Print) { } /* ************************************************************************* */ -// Test factor using a std::function type. +// Test FunctorizedFactor using a std::function type. TEST(FunctorizedFactor, Functional) { double multiplier = 2.0; Matrix X = Matrix::Identity(3, 3); @@ -145,7 +147,8 @@ TEST(FunctorizedFactor, Functional) { std::function)> functional = MultiplyFunctor(multiplier); - FunctorizedFactor factor(key, measurement, model, functional); + auto factor = + MakeFunctorizedFactor(key, measurement, model, functional); Vector error = factor.evaluateError(X); @@ -153,6 +156,7 @@ TEST(FunctorizedFactor, Functional) { } /* ************************************************************************* */ +// Test FunctorizedFactor with a lambda function. TEST(FunctorizedFactor, Lambda) { double multiplier = 2.0; Matrix X = Matrix::Identity(3, 3); @@ -166,15 +170,14 @@ TEST(FunctorizedFactor, Lambda) { return multiplier * X; }; // FunctorizedFactor factor(key, measurement, model, lambda); - auto factor = FunctorizedFactor(key, measurement, model, lambda); + auto factor = MakeFunctorizedFactor(key, measurement, model, lambda); Vector error = factor.evaluateError(X); EXPECT(assert_equal(Vector::Zero(9), error, 1e-9)); } -/* ************************************************************************* - */ +/* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);