diff --git a/gtsam/base/GenericValue.h b/gtsam/base/GenericValue.h index 7a48d85c3..1c01f7bb7 100644 --- a/gtsam/base/GenericValue.h +++ b/gtsam/base/GenericValue.h @@ -19,7 +19,10 @@ #pragma once +#include #include +#include +#include namespace gtsam { @@ -28,15 +31,63 @@ namespace traits { // trait to wrap the default equals of types template -bool equals(const T& a, const T& b, double tol) { - return a.equals(b, tol); -} +struct equals { + typedef T type; + typedef bool result_type; + bool operator()(const T& a, const T& b, double tol) { + return a.equals(b, tol); + } +}; // trait to wrap the default print of types template -void print(const T& obj, const std::string& str) { - obj.print(str); -} +struct print { + typedef T type; + typedef void result_type; + void operator()(const T& obj, const std::string& str) { + obj.print(str); + } +}; + +// equals for scalars +template<> +struct equals { + typedef double type; + typedef bool result_type; + bool operator()(double a, double b, double tol) { + return std::abs(a - b) <= tol; + } +}; + +// print for scalars +template<> +struct print { + typedef double type; + typedef void result_type; + void operator()(double a, const std::string& str) { + std::cout << str << ": " << a << std::endl; + } +}; + +// equals for Matrix types +template +struct equals > { + typedef Eigen::Matrix type; + typedef bool result_type; + bool operator()(const type& A, const type& B, double tol) { + return equal_with_abs_tol(A, B, tol); + } +}; + +// print for Matrix types +template +struct print > { + typedef Eigen::Matrix type; + typedef void result_type; + void operator()(const type& A, const std::string& str) { + std::cout << str << ": " << A << std::endl; + } +}; } @@ -80,17 +131,17 @@ public: // Cast the base class Value pointer to a templated generic class pointer const GenericValue& genericValue2 = static_cast(p); // Return the result of using the equals traits for the derived class - return traits::equals(this->value_, genericValue2.value_, tol); + return traits::equals()(this->value_, genericValue2.value_, tol); } /// non virtual equals function, uses traits bool equals(const GenericValue &other, double tol = 1e-9) const { - return traits::equals(this->value(), other.value(), tol); + return traits::equals()(this->value(), other.value(), tol); } /// Virtual print function, uses traits virtual void print(const std::string& str) const { - traits::print(value_, str); + traits::print()(value_, str); } // Serialization below: