Simplified FunctorizedFactor

By adding the helper function MakeFunctorizedFactor, we now only need to provide the argument type in the template parameter list. This considerably simplifies the factor declaration, while removing the need for argument type and return type in the functor definition.

Also added tests for std::function and lambda functions.
release/4.3a0
Varun Agrawal 2020-07-06 21:48:51 -04:00
parent 7d0e440293
commit 30ffcdd137
2 changed files with 50 additions and 35 deletions

View File

@ -25,8 +25,8 @@
namespace gtsam { namespace gtsam {
/** /**
* Factor which evaluates functor and uses the result to compute * Factor which evaluates provided unary functor and uses the result to compute
* error on provided measurement. * error with respect to the provided measurement.
* *
* Template parameters are * Template parameters are
* @param R: The return type of the functor after evaluation. * @param R: The return type of the functor after evaluation.
@ -40,13 +40,12 @@ namespace gtsam {
* class MultiplyFunctor { * class MultiplyFunctor {
* double m_; ///< simple multiplier * double m_; ///< simple multiplier
* public: * public:
* using argument_type = Matrix;
* using return_type = Matrix;
* MultiplyFunctor(double m) : m_(m) {} * MultiplyFunctor(double m) : m_(m) {}
* Matrix operator()(const Matrix &X, * Matrix operator()(const Matrix &X,
* OptionalJacobian<-1, -1> H = boost::none) const { * OptionalJacobian<-1, -1> H = boost::none) const {
* if (H) *H = m_ * Matrix::Identity(X.rows()*X.cols(), * if (H)
* X.rows()*X.cols()); return m_ * X; * *H = m_ * Matrix::Identity(X.rows()*X.cols(), X.rows()*X.cols());
* return m_ * X;
* } * }
* }; * };
* *
@ -72,7 +71,7 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1<T> {
/** Construct with given x and the parameters of the basis /** Construct with given x and the parameters of the basis
* *
* @param key: Factor key * @param key: Factor key
* @param z: Measurement object of type R * @param z: Measurement object of same type as that returned by functor
* @param model: Noise model * @param model: Noise model
* @param func: The instance of the functor object * @param func: The instance of the functor object
*/ */
@ -85,7 +84,7 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1<T> {
/// @return a deep copy of this factor /// @return a deep copy of this factor
virtual NonlinearFactor::shared_ptr clone() const { virtual NonlinearFactor::shared_ptr clone() const {
return boost::static_pointer_cast<NonlinearFactor>( return boost::static_pointer_cast<NonlinearFactor>(
NonlinearFactor::shared_ptr(new FunctorizedFactor<T, R>(*this))); NonlinearFactor::shared_ptr(new FunctorizedFactor<R, T>(*this)));
} }
Vector evaluateError(const T &params, Vector evaluateError(const T &params,
@ -108,8 +107,8 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1<T> {
} }
virtual bool equals(const NonlinearFactor &other, double tol = 1e-9) const { virtual bool equals(const NonlinearFactor &other, double tol = 1e-9) const {
const FunctorizedFactor<T, R> *e = const FunctorizedFactor<R, T> *e =
dynamic_cast<const FunctorizedFactor<T, R> *>(&other); dynamic_cast<const FunctorizedFactor<R, T> *>(&other);
const bool base = Base::equals(*e, tol); const bool base = Base::equals(*e, tol);
return e && Base::equals(other, tol) && return e && Base::equals(other, tol) &&
traits<R>::Equals(this->measured_, e->measured_, tol); traits<R>::Equals(this->measured_, e->measured_, tol);
@ -129,8 +128,21 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1<T> {
}; };
/// traits /// traits
template <typename T, typename R> template <typename R, typename T>
struct traits<FunctorizedFactor<T, R>> struct traits<FunctorizedFactor<R, T>>
: public Testable<FunctorizedFactor<T, R>> {}; : public Testable<FunctorizedFactor<R, T>> {};
/**
* Helper function to create a functorized factor.
*
* Uses function template deduction to identify return type and functor type, so
* template list only needs the functor argument type.
*/
template <typename T, typename R, typename FUNC>
FunctorizedFactor<R, T> MakeFunctorizedFactor(Key key, const R &z,
const SharedNoiseModel &model,
const FUNC func) {
return FunctorizedFactor<R, T>(key, z, model, func);
}
} // namespace gtsam } // namespace gtsam

View File

@ -34,9 +34,6 @@ class MultiplyFunctor {
double m_; ///< simple multiplier double m_; ///< simple multiplier
public: public:
using argument_type = Matrix;
using return_type = Matrix;
MultiplyFunctor(double m) : m_(m) {} MultiplyFunctor(double m) : m_(m) {}
Matrix operator()(const Matrix &X, Matrix operator()(const Matrix &X,
@ -47,12 +44,27 @@ class MultiplyFunctor {
}; };
/* ************************************************************************* */ /* ************************************************************************* */
// Test identity operation for FunctorizedFactor.
TEST(FunctorizedFactor, Identity) { TEST(FunctorizedFactor, Identity) {
Matrix X = Matrix::Identity(3, 3), measurement = Matrix::Identity(3, 3); Matrix X = Matrix::Identity(3, 3), measurement = Matrix::Identity(3, 3);
double multiplier = 1.0; double multiplier = 1.0;
auto functor = MultiplyFunctor(multiplier);
auto factor = MakeFunctorizedFactor<Matrix>(key, measurement, model, functor);
FunctorizedFactor<Matrix, Matrix> factor(key, measurement, model, Vector error = factor.evaluateError(X);
EXPECT(assert_equal(Vector::Zero(9), error, 1e-9));
}
/* ************************************************************************* */
// Test FunctorizedFactor with multiplier value of 2.
TEST(FunctorizedFactor, Multiply2) {
double multiplier = 2.0;
Matrix X = Matrix::Identity(3, 3);
Matrix measurement = multiplier * Matrix::Identity(3, 3);
auto factor = MakeFunctorizedFactor<Matrix>(key, measurement, model,
MultiplyFunctor(multiplier)); MultiplyFunctor(multiplier));
Vector error = factor.evaluateError(X); Vector error = factor.evaluateError(X);
@ -61,41 +73,30 @@ TEST(FunctorizedFactor, Identity) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(FunctorizedFactor, Multiply2) { // Test equality function for FunctorizedFactor.
double multiplier = 2.0;
Matrix X = Matrix::Identity(3, 3);
Matrix measurement = multiplier * Matrix::Identity(3, 3);
FunctorizedFactor<Matrix, Matrix> factor(key, measurement, model,
MultiplyFunctor(multiplier));
Vector error = factor.evaluateError(X);
EXPECT(assert_equal(Vector::Zero(9), error, 1e-9));
}
TEST(FunctorizedFactor, Equality) { TEST(FunctorizedFactor, Equality) {
Matrix measurement = Matrix::Identity(2, 2); Matrix measurement = Matrix::Identity(2, 2);
double multiplier = 2.0; double multiplier = 2.0;
FunctorizedFactor<Matrix, Matrix> factor1(key, measurement, model, auto factor1 = MakeFunctorizedFactor<Matrix>(key, measurement, model,
MultiplyFunctor(multiplier)); MultiplyFunctor(multiplier));
FunctorizedFactor<Matrix, Matrix> factor2(key, measurement, model, auto factor2 = MakeFunctorizedFactor<Matrix>(key, measurement, model,
MultiplyFunctor(multiplier)); MultiplyFunctor(multiplier));
EXPECT(factor1.equals(factor2)); EXPECT(factor1.equals(factor2));
} }
//****************************************************************************** /* *************************************************************************** */
// Test Jacobians of FunctorizedFactor.
TEST(FunctorizedFactor, Jacobians) { TEST(FunctorizedFactor, Jacobians) {
Matrix X = Matrix::Identity(3, 3); Matrix X = Matrix::Identity(3, 3);
Matrix actualH; Matrix actualH;
double multiplier = 2.0; double multiplier = 2.0;
FunctorizedFactor<Matrix, Matrix> factor(key, X, model, auto factor =
MultiplyFunctor(multiplier)); MakeFunctorizedFactor<Matrix>(key, X, model, MultiplyFunctor(multiplier));
Values values; Values values;
values.insert<Matrix>(key, X); values.insert<Matrix>(key, X);
@ -105,13 +106,14 @@ TEST(FunctorizedFactor, Jacobians) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Test print result of FunctorizedFactor.
TEST(FunctorizedFactor, Print) { TEST(FunctorizedFactor, Print) {
Matrix X = Matrix::Identity(2, 2); Matrix X = Matrix::Identity(2, 2);
double multiplier = 2.0; double multiplier = 2.0;
FunctorizedFactor<Matrix, Matrix> factor(key, X, model, auto factor =
MultiplyFunctor(multiplier)); MakeFunctorizedFactor<Matrix>(key, X, model, MultiplyFunctor(multiplier));
// redirect output to buffer so we can compare // redirect output to buffer so we can compare
stringstream buffer; stringstream buffer;
@ -137,7 +139,7 @@ TEST(FunctorizedFactor, Print) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Test factor using a std::function type. // Test FunctorizedFactor using a std::function type.
TEST(FunctorizedFactor, Functional) { TEST(FunctorizedFactor, Functional) {
double multiplier = 2.0; double multiplier = 2.0;
Matrix X = Matrix::Identity(3, 3); Matrix X = Matrix::Identity(3, 3);
@ -145,7 +147,8 @@ TEST(FunctorizedFactor, Functional) {
std::function<Matrix(Matrix, boost::optional<Matrix &>)> functional = std::function<Matrix(Matrix, boost::optional<Matrix &>)> functional =
MultiplyFunctor(multiplier); MultiplyFunctor(multiplier);
FunctorizedFactor<Matrix, Matrix> factor(key, measurement, model, functional); auto factor =
MakeFunctorizedFactor<Matrix>(key, measurement, model, functional);
Vector error = factor.evaluateError(X); Vector error = factor.evaluateError(X);
@ -153,6 +156,7 @@ TEST(FunctorizedFactor, Functional) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Test FunctorizedFactor with a lambda function.
TEST(FunctorizedFactor, Lambda) { TEST(FunctorizedFactor, Lambda) {
double multiplier = 2.0; double multiplier = 2.0;
Matrix X = Matrix::Identity(3, 3); Matrix X = Matrix::Identity(3, 3);
@ -166,15 +170,14 @@ TEST(FunctorizedFactor, Lambda) {
return multiplier * X; return multiplier * X;
}; };
// FunctorizedFactor<Matrix> factor(key, measurement, model, lambda); // FunctorizedFactor<Matrix> factor(key, measurement, model, lambda);
auto factor = FunctorizedFactor<Matrix>(key, measurement, model, lambda); auto factor = MakeFunctorizedFactor<Matrix>(key, measurement, model, lambda);
Vector error = factor.evaluateError(X); Vector error = factor.evaluateError(X);
EXPECT(assert_equal(Vector::Zero(9), error, 1e-9)); EXPECT(assert_equal(Vector::Zero(9), error, 1e-9));
} }
/* ************************************************************************* /* ************************************************************************* */
*/
int main() { int main() {
TestResult tr; TestResult tr;
return TestRegistry::runAllTests(tr); return TestRegistry::runAllTests(tr);