add comparator as argument to equals method
# Conflicts: # gtsam/hybrid/DCMixtureFactor.hrelease/4.3a0
parent
9982057a2b
commit
b24da8399a
|
|
@ -30,6 +30,13 @@ namespace gtsam {
|
||||||
template<typename L>
|
template<typename L>
|
||||||
class AlgebraicDecisionTree: public DecisionTree<L, double> {
|
class AlgebraicDecisionTree: public DecisionTree<L, double> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Default method for comparison of two doubles upto some tolerance.
|
||||||
|
*/
|
||||||
|
static bool DefaultComparator(double a, double b, double tol) {
|
||||||
|
return std::abs(a - b) < tol;
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef DecisionTree<L, double> Super;
|
typedef DecisionTree<L, double> Super;
|
||||||
|
|
@ -138,6 +145,12 @@ namespace gtsam {
|
||||||
return this->combine(labelC, &Ring::add);
|
return this->combine(labelC, &Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Equality method customized to node type `double`.
|
||||||
|
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9,
|
||||||
|
const std::function<bool(double, double, double)>& comparator =
|
||||||
|
&DefaultComparator) const {
|
||||||
|
return this->root_->equals(*other.root_, tol, comparator);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
// AlgebraicDecisionTree
|
// AlgebraicDecisionTree
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/VectorSpace.h>
|
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
|
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/assign/std/vector.hpp>
|
||||||
|
|
@ -77,33 +76,13 @@ namespace gtsam {
|
||||||
return (q.isLeaf() && q.sameLeaf(*this));
|
return (q.isLeaf() && q.sameLeaf(*this));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @{
|
|
||||||
/// SFINAE methods for proper substitution.
|
|
||||||
/** equality for integral types. */
|
|
||||||
template <typename T = Y>
|
|
||||||
typename std::enable_if<std::is_integral<T>::value, bool>::type
|
|
||||||
equals(const T& a, const T& b, double tol) const {
|
|
||||||
return std::abs(double(a - b)) < tol;
|
|
||||||
}
|
|
||||||
/** equality for boost::shared_ptr types. */
|
|
||||||
template <typename T = Y>
|
|
||||||
typename std::enable_if<boost::has_dereference<T>::value, bool>::type
|
|
||||||
equals(const T& a, const T& b, double tol) const {
|
|
||||||
return traits<typename T::element_type>::Equals(*a, *b, tol);
|
|
||||||
}
|
|
||||||
/** equality for all other types. */
|
|
||||||
template <typename T = Y>
|
|
||||||
typename std::enable_if<!boost::has_dereference<T>::value && !std::is_integral<T>::value, bool>::type
|
|
||||||
equals(const Y& a, const Y& b, double tol) const {
|
|
||||||
return traits<Y>::Equals(a, b, tol);
|
|
||||||
}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
/** equality up to tolerance */
|
/** equality up to tolerance */
|
||||||
bool equals(const Node& q, double tol) const override {
|
bool equals(const Node& q, double tol,
|
||||||
|
const std::function<bool(const Y&, const Y&, double)>&
|
||||||
|
comparator) const override {
|
||||||
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
||||||
if (!other) return false;
|
if (!other) return false;
|
||||||
return this->equals<Y>(this->constant_, other->constant_, tol);
|
return comparator(this->constant_, other->constant_, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** print */
|
/** print */
|
||||||
|
|
@ -304,14 +283,17 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** equality up to tolerance */
|
/** equality up to tolerance */
|
||||||
bool equals(const Node& q, double tol) const override {
|
bool equals(const Node& q, double tol,
|
||||||
const Choice* other = dynamic_cast<const Choice*> (&q);
|
const std::function<bool(const Y&, const Y&, double)>&
|
||||||
|
comparator) const override {
|
||||||
|
const Choice* other = dynamic_cast<const Choice*>(&q);
|
||||||
if (!other) return false;
|
if (!other) return false;
|
||||||
if (this->label_ != other->label_) return false;
|
if (this->label_ != other->label_) return false;
|
||||||
if (branches_.size() != other->branches_.size()) return false;
|
if (branches_.size() != other->branches_.size()) return false;
|
||||||
// we don't care about shared pointers being equal here
|
// we don't care about shared pointers being equal here
|
||||||
for (size_t i = 0; i < branches_.size(); i++)
|
for (size_t i = 0; i < branches_.size(); i++)
|
||||||
if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false;
|
if (!(branches_[i]->equals(*(other->branches_[i]), tol, comparator)))
|
||||||
|
return false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -668,9 +650,11 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other, double tol) const {
|
bool DecisionTree<L, Y>::equals(
|
||||||
return root_->equals(*other.root_, tol);
|
const DecisionTree& other, double tol,
|
||||||
|
const std::function<bool(const Y&, const Y&, double)>& comparator) const {
|
||||||
|
return root_->equals(*other.root_, tol, comparator);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
|
|
|
||||||
|
|
@ -39,13 +39,18 @@ namespace gtsam {
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class GTSAM_EXPORT DecisionTree {
|
class GTSAM_EXPORT DecisionTree {
|
||||||
|
|
||||||
/// default method used by `formatter` when printing.
|
/// Default method used by `formatter` when printing.
|
||||||
static std::string DefaultFormatter(const L& x) {
|
static std::string DefaultFormatter(const L& x) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << x;
|
ss << x;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Default method for comparison of two objects of type Y.
|
||||||
|
static bool DefaultComparator(const Y& a, const Y& b, double tol) {
|
||||||
|
return a == b;
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/** Handy typedefs for unary and binary function types */
|
/** Handy typedefs for unary and binary function types */
|
||||||
|
|
@ -92,7 +97,9 @@ namespace gtsam {
|
||||||
virtual void dot(std::ostream& os, bool showZero) const = 0;
|
virtual void dot(std::ostream& os, bool showZero) const = 0;
|
||||||
virtual bool sameLeaf(const Leaf& q) const = 0;
|
virtual bool sameLeaf(const Leaf& q) const = 0;
|
||||||
virtual bool sameLeaf(const Node& q) const = 0;
|
virtual bool sameLeaf(const Node& q) const = 0;
|
||||||
virtual bool equals(const Node& other, double tol = 1e-9) const = 0;
|
virtual bool equals(const Node& other, double tol = 1e-9,
|
||||||
|
const std::function<bool(const Y&, const Y&, double)>&
|
||||||
|
comparator = &DefaultComparator) const = 0;
|
||||||
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
||||||
virtual Ptr apply(const Unary& op) const = 0;
|
virtual Ptr apply(const Unary& op) const = 0;
|
||||||
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
||||||
|
|
@ -178,7 +185,9 @@ namespace gtsam {
|
||||||
&DefaultFormatter) const;
|
&DefaultFormatter) const;
|
||||||
|
|
||||||
// Testable
|
// Testable
|
||||||
bool equals(const DecisionTree& other, double tol = 1e-9) const;
|
bool equals(const DecisionTree& other, double tol = 1e-9,
|
||||||
|
const std::function<bool(const Y&, const Y&, double)>&
|
||||||
|
comparator = &DefaultComparator) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,19 @@ void dot(const T&f, const string& filename) {
|
||||||
|
|
||||||
#define DOT(x)(dot(x,#x))
|
#define DOT(x)(dot(x,#x))
|
||||||
|
|
||||||
struct Crazy { int a; double b; };
|
struct Crazy {
|
||||||
|
int a;
|
||||||
|
double b;
|
||||||
|
|
||||||
|
bool equals(const Crazy& other, double tol = 1e-12) const {
|
||||||
|
return a == other.a && std::abs(b - other.b) < tol;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(const Crazy& other) const {
|
||||||
|
return this->equals(other);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
typedef DecisionTree<string,Crazy> CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be)
|
typedef DecisionTree<string,Crazy> CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be)
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue