Add variadic-template N-ary ExpressionFactor

release/4.3a0
Jose Luis Blanco-Claraco 2020-07-27 11:57:31 +02:00 committed by Jose Luis Blanco Claraco
parent 8b1f3e1745
commit 947479e9de
No known key found for this signature in database
GPG Key ID: D443304FBD70A641
2 changed files with 167 additions and 2 deletions

View File

@ -19,9 +19,10 @@
#pragma once
#include <array>
#include <gtsam/base/Testable.h>
#include <gtsam/nonlinear/Expression.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
#include <gtsam/base/Testable.h>
#include <numeric>
namespace gtsam {
@ -287,4 +288,71 @@ template <typename T, typename A1, typename A2>
struct traits<ExpressionFactor2<T,A1,A2>> : public Testable<ExpressionFactor2<T,A1,A2>> {};
// ExpressionFactor2
}// \ namespace gtsam
/**
* N-ary variadic template for ExpressionFactor meant as a base class for N-ary
* factors. Enforces an 'expression' method with N keys.
* Derived class (an N-factor!) needs to call 'initialize'.
*
* Does not provide backward compatible 'evaluateError'.
*
* \tparam T Type for measurements. The rest of template arguments are types
* for the N key-indexed Values.
*
*/
template <typename T, typename... Args>
class ExpressionFactorN : public ExpressionFactor<T> {
public:
static const std::size_t NARY_EXPRESSION_SIZE = sizeof...(Args);
using ArrayNKeys = std::array<Key, NARY_EXPRESSION_SIZE>;
/// Destructor
virtual ~ExpressionFactorN() = default;
// Don't provide backward compatible evaluateVector(), due to its problematic
// variable length of optional Jacobian arguments. Vector evaluateError(const
// Args... args,...);
/// Recreate expression from given keys_ and measured_, used in load
/// Needed to deserialize a derived factor
virtual Expression<T> expression(const ArrayNKeys &keys) const {
throw std::runtime_error(
"ExpressionFactorN::expression not provided: cannot deserialize.");
}
protected:
/// Default constructor, for serialization
ExpressionFactorN() = default;
/// Constructor takes care of keys, but still need to call initialize
ExpressionFactorN(const ArrayNKeys &keys, const SharedNoiseModel &noiseModel,
const T &measurement)
: ExpressionFactor<T>(noiseModel, measurement) {
for (const auto &key : keys)
Factor::keys_.push_back(key);
}
private:
/// Return an expression that predicts the measurement given Values
Expression<T> expression() const override {
ArrayNKeys keys;
int idx = 0;
for (const auto &key : Factor::keys_)
keys[idx++] = key;
return expression(keys);
}
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &boost::serialization::make_nvp(
"ExpressionFactorN",
boost::serialization::base_object<ExpressionFactor<T>>(*this));
}
};
/// traits
template <typename T, typename... Args>
struct traits<ExpressionFactor2<T, Args...>>
: public Testable<ExpressionFactorN<T, Args...>> {};
// ExpressionFactorN
} // namespace gtsam

View File

@ -630,6 +630,103 @@ TEST(ExpressionFactor, MultiplyWithInverseFunction) {
EXPECT_CORRECT_FACTOR_JACOBIANS(factor, values, 1e-5, 1e-5);
}
/* ************************************************************************* */
// Test N-ary variadic template
class TestNaryFactor
: public gtsam::ExpressionFactorN<gtsam::Point3 /*return type*/,
gtsam::Rot3, gtsam::Point3,
gtsam::Rot3,gtsam::Point3> {
private:
using This = TestNaryFactor;
using Base =
gtsam::ExpressionFactorN<gtsam::Point3 /*return type*/,
gtsam::Rot3, gtsam::Point3, gtsam::Rot3, gtsam::Point3>;
public:
/// default constructor
TestNaryFactor() = default;
~TestNaryFactor() override = default;
TestNaryFactor(gtsam::Key kR1, gtsam::Key kV1, gtsam::Key kR2, gtsam::Key kV2,
const gtsam::SharedNoiseModel &model, const gtsam::Point3& measured)
: Base({kR1, kV1, kR2, kV2}, model, measured) {
this->initialize(expression({kR1, kV1, kR2, kV2}));
}
/// @return a deep copy of this factor
gtsam::NonlinearFactor::shared_ptr clone() const override {
return boost::static_pointer_cast<gtsam::NonlinearFactor>(
gtsam::NonlinearFactor::shared_ptr(new This(*this)));
}
// Return measurement expression
gtsam::Expression<gtsam::Point3> expression(
const std::array<gtsam::Key, NARY_EXPRESSION_SIZE> &keys) const override {
gtsam::Expression<gtsam::Rot3> R1_(keys[0]);
gtsam::Expression<gtsam::Point3> V1_(keys[1]);
gtsam::Expression<gtsam::Rot3> R2_(keys[2]);
gtsam::Expression<gtsam::Point3> V2_(keys[3]);
return {gtsam::rotate(R1_, V1_) - gtsam::rotate(R2_, V2_)};
}
/** print */
void print(const std::string &s,
const gtsam::KeyFormatter &keyFormatter =
gtsam::DefaultKeyFormatter) const override {
std::cout << s << "TestNaryFactor("
<< keyFormatter(Factor::keys_[0]) << ","
<< keyFormatter(Factor::keys_[1]) << ","
<< keyFormatter(Factor::keys_[2]) << ","
<< keyFormatter(Factor::keys_[3]) << ")\n";
gtsam::traits<gtsam::Point3>::Print(measured_, " measured: ");
this->noiseModel_->print(" noise model: ");
}
/** equals */
bool equals(const gtsam::NonlinearFactor &expected,
double tol = 1e-9) const override {
const This *e = dynamic_cast<const This *>(&expected);
return e != nullptr && Base::equals(*e, tol) &&
gtsam::traits<gtsam::Point3>::Equals(measured_,e->measured_, tol);
}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &boost::serialization::make_nvp(
"TestNaryFactor",
boost::serialization::base_object<Base>(*this));
ar &BOOST_SERIALIZATION_NVP(measured_);
}
};
TEST(ExpressionFactor, variadicTemplate) {
using gtsam::symbol_shorthand::R;
using gtsam::symbol_shorthand::V;
// Create factor
TestNaryFactor f(R(0),V(0), R(1), V(1), noiseModel::Unit::Create(3), Point3(0,0,0));
// Create some values
Values values;
values.insert(R(0), Rot3::Ypr(0.1, 0.2, 0.3));
values.insert(V(0), Point3(1, 2, 3));
values.insert(R(1), Rot3::Ypr(0.2, 0.5, 0.2));
values.insert(V(1), Point3(5, 6, 7));
// Check unwhitenedError
std::vector<Matrix> H(4);
Vector actual = f.unwhitenedError(values, H);
EXPECT_LONGS_EQUAL(4, H.size());
EXPECT(assert_equal(Eigen::Vector3d(-5.63578115, -4.85353243, -1.4801204), actual, 1e-5));
EXPECT_CORRECT_FACTOR_JACOBIANS(f, values, 1e-8, 1e-5);
}
/* ************************************************************************* */
int main() {
TestResult tr;