Simplified FunctorizedFactor

By adding the helper function MakeFunctorizedFactor, we now only need to provide the argument type in the template parameter list. This considerably simplifies the factor declaration, while removing the need for argument type and return type in the functor definition.

Also added tests for std::function and lambda functions.
release/4.3a0
Varun Agrawal 2020-07-06 21:48:51 -04:00
parent 7d0e440293
commit 30ffcdd137
2 changed files with 50 additions and 35 deletions

View File

@ -25,8 +25,8 @@
namespace gtsam {
/**
* Factor which evaluates functor and uses the result to compute
* error on provided measurement.
* Factor which evaluates provided unary functor and uses the result to compute
* error with respect to the provided measurement.
*
* Template parameters are
* @param R: The return type of the functor after evaluation.
@ -40,13 +40,12 @@ namespace gtsam {
* class MultiplyFunctor {
* double m_; ///< simple multiplier
* public:
* using argument_type = Matrix;
* using return_type = Matrix;
* MultiplyFunctor(double m) : m_(m) {}
* Matrix operator()(const Matrix &X,
* OptionalJacobian<-1, -1> H = boost::none) const {
* if (H) *H = m_ * Matrix::Identity(X.rows()*X.cols(),
* X.rows()*X.cols()); return m_ * X;
* if (H)
* *H = m_ * Matrix::Identity(X.rows()*X.cols(), X.rows()*X.cols());
* return m_ * X;
* }
* };
*
@ -72,7 +71,7 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1<T> {
/** Construct with given x and the parameters of the basis
*
* @param key: Factor key
* @param z: Measurement object of type R
* @param z: Measurement object of same type as that returned by functor
* @param model: Noise model
* @param func: The instance of the functor object
*/
@ -85,7 +84,7 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1<T> {
/// @return a deep copy of this factor
virtual NonlinearFactor::shared_ptr clone() const {
return boost::static_pointer_cast<NonlinearFactor>(
NonlinearFactor::shared_ptr(new FunctorizedFactor<T, R>(*this)));
NonlinearFactor::shared_ptr(new FunctorizedFactor<R, T>(*this)));
}
Vector evaluateError(const T &params,
@ -108,8 +107,8 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1<T> {
}
virtual bool equals(const NonlinearFactor &other, double tol = 1e-9) const {
const FunctorizedFactor<T, R> *e =
dynamic_cast<const FunctorizedFactor<T, R> *>(&other);
const FunctorizedFactor<R, T> *e =
dynamic_cast<const FunctorizedFactor<R, T> *>(&other);
const bool base = Base::equals(*e, tol);
return e && Base::equals(other, tol) &&
traits<R>::Equals(this->measured_, e->measured_, tol);
@ -129,8 +128,21 @@ class GTSAM_EXPORT FunctorizedFactor : public NoiseModelFactor1<T> {
};
/// traits
template <typename T, typename R>
struct traits<FunctorizedFactor<T, R>>
: public Testable<FunctorizedFactor<T, R>> {};
template <typename R, typename T>
struct traits<FunctorizedFactor<R, T>>
: public Testable<FunctorizedFactor<R, T>> {};
/**
* Helper function to create a functorized factor.
*
* Uses function template deduction to identify return type and functor type, so
* template list only needs the functor argument type.
*/
template <typename T, typename R, typename FUNC>
FunctorizedFactor<R, T> MakeFunctorizedFactor(Key key, const R &z,
const SharedNoiseModel &model,
const FUNC func) {
return FunctorizedFactor<R, T>(key, z, model, func);
}
} // namespace gtsam

View File

@ -34,9 +34,6 @@ class MultiplyFunctor {
double m_; ///< simple multiplier
public:
using argument_type = Matrix;
using return_type = Matrix;
MultiplyFunctor(double m) : m_(m) {}
Matrix operator()(const Matrix &X,
@ -47,12 +44,27 @@ class MultiplyFunctor {
};
/* ************************************************************************* */
// Test identity operation for FunctorizedFactor.
TEST(FunctorizedFactor, Identity) {
Matrix X = Matrix::Identity(3, 3), measurement = Matrix::Identity(3, 3);
double multiplier = 1.0;
auto functor = MultiplyFunctor(multiplier);
auto factor = MakeFunctorizedFactor<Matrix>(key, measurement, model, functor);
FunctorizedFactor<Matrix, Matrix> factor(key, measurement, model,
Vector error = factor.evaluateError(X);
EXPECT(assert_equal(Vector::Zero(9), error, 1e-9));
}
/* ************************************************************************* */
// Test FunctorizedFactor with multiplier value of 2.
TEST(FunctorizedFactor, Multiply2) {
double multiplier = 2.0;
Matrix X = Matrix::Identity(3, 3);
Matrix measurement = multiplier * Matrix::Identity(3, 3);
auto factor = MakeFunctorizedFactor<Matrix>(key, measurement, model,
MultiplyFunctor(multiplier));
Vector error = factor.evaluateError(X);
@ -61,41 +73,30 @@ TEST(FunctorizedFactor, Identity) {
}
/* ************************************************************************* */
TEST(FunctorizedFactor, Multiply2) {
double multiplier = 2.0;
Matrix X = Matrix::Identity(3, 3);
Matrix measurement = multiplier * Matrix::Identity(3, 3);
FunctorizedFactor<Matrix, Matrix> factor(key, measurement, model,
MultiplyFunctor(multiplier));
Vector error = factor.evaluateError(X);
EXPECT(assert_equal(Vector::Zero(9), error, 1e-9));
}
// Test equality function for FunctorizedFactor.
TEST(FunctorizedFactor, Equality) {
Matrix measurement = Matrix::Identity(2, 2);
double multiplier = 2.0;
FunctorizedFactor<Matrix, Matrix> factor1(key, measurement, model,
auto factor1 = MakeFunctorizedFactor<Matrix>(key, measurement, model,
MultiplyFunctor(multiplier));
FunctorizedFactor<Matrix, Matrix> factor2(key, measurement, model,
auto factor2 = MakeFunctorizedFactor<Matrix>(key, measurement, model,
MultiplyFunctor(multiplier));
EXPECT(factor1.equals(factor2));
}
//******************************************************************************
/* *************************************************************************** */
// Test Jacobians of FunctorizedFactor.
TEST(FunctorizedFactor, Jacobians) {
Matrix X = Matrix::Identity(3, 3);
Matrix actualH;
double multiplier = 2.0;
FunctorizedFactor<Matrix, Matrix> factor(key, X, model,
MultiplyFunctor(multiplier));
auto factor =
MakeFunctorizedFactor<Matrix>(key, X, model, MultiplyFunctor(multiplier));
Values values;
values.insert<Matrix>(key, X);
@ -105,13 +106,14 @@ TEST(FunctorizedFactor, Jacobians) {
}
/* ************************************************************************* */
// Test print result of FunctorizedFactor.
TEST(FunctorizedFactor, Print) {
Matrix X = Matrix::Identity(2, 2);
double multiplier = 2.0;
FunctorizedFactor<Matrix, Matrix> factor(key, X, model,
MultiplyFunctor(multiplier));
auto factor =
MakeFunctorizedFactor<Matrix>(key, X, model, MultiplyFunctor(multiplier));
// redirect output to buffer so we can compare
stringstream buffer;
@ -137,7 +139,7 @@ TEST(FunctorizedFactor, Print) {
}
/* ************************************************************************* */
// Test factor using a std::function type.
// Test FunctorizedFactor using a std::function type.
TEST(FunctorizedFactor, Functional) {
double multiplier = 2.0;
Matrix X = Matrix::Identity(3, 3);
@ -145,7 +147,8 @@ TEST(FunctorizedFactor, Functional) {
std::function<Matrix(Matrix, boost::optional<Matrix &>)> functional =
MultiplyFunctor(multiplier);
FunctorizedFactor<Matrix, Matrix> factor(key, measurement, model, functional);
auto factor =
MakeFunctorizedFactor<Matrix>(key, measurement, model, functional);
Vector error = factor.evaluateError(X);
@ -153,6 +156,7 @@ TEST(FunctorizedFactor, Functional) {
}
/* ************************************************************************* */
// Test FunctorizedFactor with a lambda function.
TEST(FunctorizedFactor, Lambda) {
double multiplier = 2.0;
Matrix X = Matrix::Identity(3, 3);
@ -166,15 +170,14 @@ TEST(FunctorizedFactor, Lambda) {
return multiplier * X;
};
// FunctorizedFactor<Matrix> factor(key, measurement, model, lambda);
auto factor = FunctorizedFactor<Matrix>(key, measurement, model, lambda);
auto factor = MakeFunctorizedFactor<Matrix>(key, measurement, model, lambda);
Vector error = factor.evaluateError(X);
EXPECT(assert_equal(Vector::Zero(9), error, 1e-9));
}
/* *************************************************************************
*/
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);