diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 9cc55ed6a..3469ac5cf 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -30,6 +30,13 @@ namespace gtsam { template class AlgebraicDecisionTree: public DecisionTree { + /** + * @brief Default method for comparison of two doubles upto some tolerance. + */ + static bool DefaultComparator(double a, double b, double tol) { + return std::abs(a - b) < tol; + } + public: typedef DecisionTree Super; @@ -138,6 +145,12 @@ namespace gtsam { return this->combine(labelC, &Ring::add); } + /// Equality method customized to node type `double`. + bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9, + const std::function& comparator = + &DefaultComparator) const { + return this->root_->equals(*other.root_, tol, comparator); + } }; // AlgebraicDecisionTree diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index ec6222fd7..f531e2b98 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -20,7 +20,6 @@ #pragma once #include -#include #include #include @@ -77,33 +76,13 @@ namespace gtsam { return (q.isLeaf() && q.sameLeaf(*this)); } - /// @{ - /// SFINAE methods for proper substitution. - /** equality for integral types. */ - template - typename std::enable_if::value, bool>::type - equals(const T& a, const T& b, double tol) const { - return std::abs(double(a - b)) < tol; - } - /** equality for boost::shared_ptr types. */ - template - typename std::enable_if::value, bool>::type - equals(const T& a, const T& b, double tol) const { - return traits::Equals(*a, *b, tol); - } - /** equality for all other types. */ - template - typename std::enable_if::value && !std::is_integral::value, bool>::type - equals(const Y& a, const Y& b, double tol) const { - return traits::Equals(a, b, tol); - } - /// @} - /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { + bool equals(const Node& q, double tol, + const std::function& + comparator) const override { const Leaf* other = dynamic_cast(&q); if (!other) return false; - return this->equals(this->constant_, other->constant_, tol); + return comparator(this->constant_, other->constant_, tol); } /** print */ @@ -304,14 +283,17 @@ namespace gtsam { } /** equality up to tolerance */ - bool equals(const Node& q, double tol) const override { - const Choice* other = dynamic_cast (&q); + bool equals(const Node& q, double tol, + const std::function& + comparator) const override { + const Choice* other = dynamic_cast(&q); if (!other) return false; if (this->label_ != other->label_) return false; if (branches_.size() != other->branches_.size()) return false; // we don't care about shared pointers being equal here for (size_t i = 0; i < branches_.size(); i++) - if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false; + if (!(branches_[i]->equals(*(other->branches_[i]), tol, comparator))) + return false; return true; } @@ -668,9 +650,11 @@ namespace gtsam { } /*********************************************************************************/ - template - bool DecisionTree::equals(const DecisionTree& other, double tol) const { - return root_->equals(*other.root_, tol); + template + bool DecisionTree::equals( + const DecisionTree& other, double tol, + const std::function& comparator) const { + return root_->equals(*other.root_, tol, comparator); } template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 498fce5aa..eb23bc5ce 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -39,13 +39,18 @@ namespace gtsam { template class GTSAM_EXPORT DecisionTree { - /// default method used by `formatter` when printing. + /// Default method used by `formatter` when printing. static std::string DefaultFormatter(const L& x) { std::stringstream ss; ss << x; return ss.str(); } + /// Default method for comparison of two objects of type Y. + static bool DefaultComparator(const Y& a, const Y& b, double tol) { + return a == b; + } + public: /** Handy typedefs for unary and binary function types */ @@ -92,7 +97,9 @@ namespace gtsam { virtual void dot(std::ostream& os, bool showZero) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0; - virtual bool equals(const Node& other, double tol = 1e-9) const = 0; + virtual bool equals(const Node& other, double tol = 1e-9, + const std::function& + comparator = &DefaultComparator) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; @@ -178,7 +185,9 @@ namespace gtsam { &DefaultFormatter) const; // Testable - bool equals(const DecisionTree& other, double tol = 1e-9) const; + bool equals(const DecisionTree& other, double tol = 1e-9, + const std::function& + comparator = &DefaultComparator) const; /// @} /// @name Standard Interface diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 96f503abc..dbf6cc44b 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -40,7 +40,19 @@ void dot(const T&f, const string& filename) { #define DOT(x)(dot(x,#x)) -struct Crazy { int a; double b; }; +struct Crazy { + int a; + double b; + + bool equals(const Crazy& other, double tol = 1e-12) const { + return a == other.a && std::abs(b - other.b) < tol; + } + + bool operator==(const Crazy& other) const { + return this->equals(other); + } +}; + typedef DecisionTree CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) // traits