diff --git a/gtsam/nonlinear/ExpressionFactor.h b/gtsam/nonlinear/ExpressionFactor.h index 85527f6dc..41eb1642e 100644 --- a/gtsam/nonlinear/ExpressionFactor.h +++ b/gtsam/nonlinear/ExpressionFactor.h @@ -19,9 +19,10 @@ #pragma once +#include +#include #include #include -#include #include namespace gtsam { @@ -287,4 +288,71 @@ template struct traits> : public Testable> {}; // ExpressionFactor2 -}// \ namespace gtsam +/** + * N-ary variadic template for ExpressionFactor meant as a base class for N-ary + * factors. Enforces an 'expression' method with N keys. + * Derived class (an N-factor!) needs to call 'initialize'. + * + * Does not provide backward compatible 'evaluateError'. + * + * \tparam T Type for measurements. The rest of template arguments are types + * for the N key-indexed Values. + * + */ +template +class ExpressionFactorN : public ExpressionFactor { +public: + static const std::size_t NARY_EXPRESSION_SIZE = sizeof...(Args); + using ArrayNKeys = std::array; + + /// Destructor + virtual ~ExpressionFactorN() = default; + + // Don't provide backward compatible evaluateVector(), due to its problematic + // variable length of optional Jacobian arguments. Vector evaluateError(const + // Args... args,...); + + /// Recreate expression from given keys_ and measured_, used in load + /// Needed to deserialize a derived factor + virtual Expression expression(const ArrayNKeys &keys) const { + throw std::runtime_error( + "ExpressionFactorN::expression not provided: cannot deserialize."); + } + +protected: + /// Default constructor, for serialization + ExpressionFactorN() = default; + + /// Constructor takes care of keys, but still need to call initialize + ExpressionFactorN(const ArrayNKeys &keys, const SharedNoiseModel &noiseModel, + const T &measurement) + : ExpressionFactor(noiseModel, measurement) { + for (const auto &key : keys) + Factor::keys_.push_back(key); + } + +private: + /// Return an expression that predicts the measurement given Values + Expression expression() const override { + ArrayNKeys keys; + int idx = 0; + for (const auto &key : Factor::keys_) + keys[idx++] = key; + return expression(keys); + } + + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &boost::serialization::make_nvp( + "ExpressionFactorN", + boost::serialization::base_object>(*this)); + } +}; +/// traits +template +struct traits> + : public Testable> {}; +// ExpressionFactorN + +} // namespace gtsam diff --git a/tests/testExpressionFactor.cpp b/tests/testExpressionFactor.cpp index d33c7ba1d..e3e37e7c7 100644 --- a/tests/testExpressionFactor.cpp +++ b/tests/testExpressionFactor.cpp @@ -630,6 +630,103 @@ TEST(ExpressionFactor, MultiplyWithInverseFunction) { EXPECT_CORRECT_FACTOR_JACOBIANS(factor, values, 1e-5, 1e-5); } + +/* ************************************************************************* */ +// Test N-ary variadic template +class TestNaryFactor + : public gtsam::ExpressionFactorN { +private: + using This = TestNaryFactor; + using Base = + gtsam::ExpressionFactorN; + +public: + /// default constructor + TestNaryFactor() = default; + ~TestNaryFactor() override = default; + + TestNaryFactor(gtsam::Key kR1, gtsam::Key kV1, gtsam::Key kR2, gtsam::Key kV2, + const gtsam::SharedNoiseModel &model, const gtsam::Point3& measured) + : Base({kR1, kV1, kR2, kV2}, model, measured) { + this->initialize(expression({kR1, kV1, kR2, kV2})); + } + + /// @return a deep copy of this factor + gtsam::NonlinearFactor::shared_ptr clone() const override { + return boost::static_pointer_cast( + gtsam::NonlinearFactor::shared_ptr(new This(*this))); + } + + // Return measurement expression + gtsam::Expression expression( + const std::array &keys) const override { + gtsam::Expression R1_(keys[0]); + gtsam::Expression V1_(keys[1]); + gtsam::Expression R2_(keys[2]); + gtsam::Expression V2_(keys[3]); + return {gtsam::rotate(R1_, V1_) - gtsam::rotate(R2_, V2_)}; + } + + /** print */ + void print(const std::string &s, + const gtsam::KeyFormatter &keyFormatter = + gtsam::DefaultKeyFormatter) const override { + std::cout << s << "TestNaryFactor(" + << keyFormatter(Factor::keys_[0]) << "," + << keyFormatter(Factor::keys_[1]) << "," + << keyFormatter(Factor::keys_[2]) << "," + << keyFormatter(Factor::keys_[3]) << ")\n"; + gtsam::traits::Print(measured_, " measured: "); + this->noiseModel_->print(" noise model: "); + } + + /** equals */ + bool equals(const gtsam::NonlinearFactor &expected, + double tol = 1e-9) const override { + const This *e = dynamic_cast(&expected); + return e != nullptr && Base::equals(*e, tol) && + gtsam::traits::Equals(measured_,e->measured_, tol); + } + +private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &boost::serialization::make_nvp( + "TestNaryFactor", + boost::serialization::base_object(*this)); + ar &BOOST_SERIALIZATION_NVP(measured_); + } +}; + +TEST(ExpressionFactor, variadicTemplate) { + using gtsam::symbol_shorthand::R; + using gtsam::symbol_shorthand::V; + + // Create factor + TestNaryFactor f(R(0),V(0), R(1), V(1), noiseModel::Unit::Create(3), Point3(0,0,0)); + + // Create some values + Values values; + values.insert(R(0), Rot3::Ypr(0.1, 0.2, 0.3)); + values.insert(V(0), Point3(1, 2, 3)); + values.insert(R(1), Rot3::Ypr(0.2, 0.5, 0.2)); + values.insert(V(1), Point3(5, 6, 7)); + + // Check unwhitenedError + std::vector H(4); + Vector actual = f.unwhitenedError(values, H); + EXPECT_LONGS_EQUAL(4, H.size()); + EXPECT(assert_equal(Eigen::Vector3d(-5.63578115, -4.85353243, -1.4801204), actual, 1e-5)); + + EXPECT_CORRECT_FACTOR_JACOBIANS(f, values, 1e-8, 1e-5); +} + + /* ************************************************************************* */ int main() { TestResult tr;