rewrite evaluateError to use SFINAE instead of conditional compilation

release/4.3a0
kartik arcot 2023-01-17 13:45:46 -08:00
parent 65bb6aea63
commit 92874f76fa
1 changed files with 41 additions and 20 deletions

View File

@ -30,6 +30,7 @@
#include <boost/serialization/base_object.hpp>
#include <cstddef>
#include <type_traits>
namespace gtsam {
@ -456,6 +457,28 @@ protected:
template <typename Container>
using IsContainerOfKeys = IsConvertible<ContainerElementType<Container>, Key>;
/** A helper alias to check if a list of args
* are all references to a matrix or not. It will be used
* to choose the right overload of evaluateError.
*/
template <typename Ret, typename ...Args>
using AreAllMatrixRefs = std::enable_if_t<(... &&
std::is_convertible<Args, Matrix&>::value), Ret>;
template<typename Arg>
using IsMatrixPointer = std::is_same<typename std::decay_t<Arg>, Matrix*>;
template<typename Arg>
using IsNullpointer = std::is_same<typename std::decay_t<Arg>, std::nullptr_t>;
/** A helper alias to check if a list of args
* are all pointers to a matrix or not. It will be used
* to choose the right overload of evaluateError.
*/
template <typename Ret, typename ...Args>
using AreAllMatrixPtrs = std::enable_if_t<(... &&
(IsMatrixPointer<Args>::value || IsNullpointer<Args>::value)), Ret>;
/// @}
/* Like std::void_t, except produces `OptionalMatrixType` instead of
@ -622,7 +645,6 @@ protected:
* public:
* using NoiseModelFactorN<list the value types here>::evaluateError;
*/
Vector evaluateError(const ValueTypes&... x, MatrixTypeT<ValueTypes>&... H) const {
return evaluateError(x..., (&H)...);
}
@ -642,28 +664,27 @@ protected:
}
/** Some (but not all) optional Jacobians are omitted (function overload)
*
* and the jacobians are l-value references to matrices.
* e.g. `const Vector error = factor.evaluateError(pose, point, Hpose);`
*/
template <typename... OptionalJacArgs, typename = IndexIsValid<sizeof...(OptionalJacArgs) + 1>>
inline Vector evaluateError(const ValueTypes&... x, OptionalJacArgs&&... H) const {
// A check to ensure all arguments passed are either matrices or are all pointers to matrices
constexpr bool are_all_mat = (... && (std::is_same<Matrix, std::decay_t<OptionalJacArgs>>::value));
// The pointers can either be of std::nonetype_t or of Matrix* type
constexpr bool are_all_ptrs = (... && (std::is_same<OptionalMatrixType, std::decay_t<OptionalJacArgs>>::value ||
std::is_same<std::nullptr_t, std::decay_t<OptionalJacArgs>>::value));
static_assert((are_all_mat || are_all_ptrs),
"Arguments that are passed to the evaluateError function can only be of following the types: Matrix, "
"or Matrix*");
// If they pass all matrices then we want to pass their pointers instead
if constexpr (are_all_mat) {
inline AreAllMatrixRefs<Vector, OptionalJacArgs...> evaluateError(const ValueTypes&... x,
OptionalJacArgs&&... H) const {
return evaluateError(x..., (&H)...);
} else {
}
/** Some (but not all) optional Jacobians are omitted (function overload)
* and the jacobians are pointers to matrices.
* e.g. `const Vector error = factor.evaluateError(pose, point, &Hpose);`
*/
template <typename... OptionalJacArgs, typename = IndexIsValid<sizeof...(OptionalJacArgs) + 1>>
inline AreAllMatrixPtrs<Vector, OptionalJacArgs...> evaluateError(const ValueTypes&... x,
OptionalJacArgs&&... H) const {
// If they are pointer version, ensure to cast them all to be Matrix* types
// This will ensure any arguments inferred as std::nonetype_t are cast to (Matrix*) nullptr
// This guides the compiler to the correct overload which is the one that takes pointers
return evaluateError(x..., std::forward<OptionalJacArgs>(H)..., static_cast<OptionalMatrixType>(OptionalNone));
}
return evaluateError(x...,
std::forward<OptionalJacArgs>(H)..., static_cast<OptionalMatrixType>(OptionalNone));
}
/// @}