address more review comments
# Conflicts: # gtsam/hybrid/DCGaussianMixtureFactor.h # gtsam/hybrid/DCMixtureFactor.hrelease/4.3a0
parent
b24da8399a
commit
26c48a8c1f
|
|
@ -30,13 +30,6 @@ namespace gtsam {
|
||||||
template<typename L>
|
template<typename L>
|
||||||
class AlgebraicDecisionTree: public DecisionTree<L, double> {
|
class AlgebraicDecisionTree: public DecisionTree<L, double> {
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Default method for comparison of two doubles upto some tolerance.
|
|
||||||
*/
|
|
||||||
static bool DefaultComparator(double a, double b, double tol) {
|
|
||||||
return std::abs(a - b) < tol;
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef DecisionTree<L, double> Super;
|
typedef DecisionTree<L, double> Super;
|
||||||
|
|
@ -146,9 +139,11 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Equality method customized to node type `double`.
|
/// Equality method customized to node type `double`.
|
||||||
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9,
|
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
|
||||||
const std::function<bool(double, double, double)>& comparator =
|
// lambda for comparison of two doubles upto some tolerance.
|
||||||
&DefaultComparator) const {
|
auto comparator = [](double a, double b, double tol) {
|
||||||
|
return std::abs(a - b) < tol;
|
||||||
|
};
|
||||||
return this->root_->equals(*other.root_, tol, comparator);
|
return this->root_->equals(*other.root_, tol, comparator);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
|
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/assign/std/vector.hpp>
|
||||||
|
|
@ -78,8 +77,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** equality up to tolerance */
|
/** equality up to tolerance */
|
||||||
bool equals(const Node& q, double tol,
|
bool equals(const Node& q, double tol,
|
||||||
const std::function<bool(const Y&, const Y&, double)>&
|
const ComparatorFunc& comparator) const override {
|
||||||
comparator) const override {
|
|
||||||
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
||||||
if (!other) return false;
|
if (!other) return false;
|
||||||
return comparator(this->constant_, other->constant_, tol);
|
return comparator(this->constant_, other->constant_, tol);
|
||||||
|
|
@ -87,7 +85,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** print */
|
/** print */
|
||||||
void print(const std::string& s,
|
void print(const std::string& s,
|
||||||
const std::function<std::string(L)>& formatter) const override {
|
const FormatterFunc& formatter) const override {
|
||||||
bool showZero = true;
|
bool showZero = true;
|
||||||
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
|
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
|
||||||
}
|
}
|
||||||
|
|
@ -241,7 +239,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** print (as a tree) */
|
/** print (as a tree) */
|
||||||
void print(const std::string& s,
|
void print(const std::string& s,
|
||||||
const std::function<std::string(L)>& formatter) const override {
|
const FormatterFunc& formatter) const override {
|
||||||
std::cout << s << " Choice(";
|
std::cout << s << " Choice(";
|
||||||
std::cout << formatter(label_) << ") " << std::endl;
|
std::cout << formatter(label_) << ") " << std::endl;
|
||||||
for (size_t i = 0; i < branches_.size(); i++)
|
for (size_t i = 0; i < branches_.size(); i++)
|
||||||
|
|
@ -284,8 +282,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** equality up to tolerance */
|
/** equality up to tolerance */
|
||||||
bool equals(const Node& q, double tol,
|
bool equals(const Node& q, double tol,
|
||||||
const std::function<bool(const Y&, const Y&, double)>&
|
const ComparatorFunc& comparator) const override {
|
||||||
comparator) const override {
|
|
||||||
const Choice* other = dynamic_cast<const Choice*>(&q);
|
const Choice* other = dynamic_cast<const Choice*>(&q);
|
||||||
if (!other) return false;
|
if (!other) return false;
|
||||||
if (this->label_ != other->label_) return false;
|
if (this->label_ != other->label_) return false;
|
||||||
|
|
@ -651,16 +648,14 @@ namespace gtsam {
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
bool DecisionTree<L, Y>::equals(
|
bool DecisionTree<L, Y>::equals(const DecisionTree& other, double tol,
|
||||||
const DecisionTree& other, double tol,
|
const ComparatorFunc& comparator) const {
|
||||||
const std::function<bool(const Y&, const Y&, double)>& comparator) const {
|
|
||||||
return root_->equals(*other.root_, tol, comparator);
|
return root_->equals(*other.root_, tol, comparator);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
void DecisionTree<L, Y>::print(
|
void DecisionTree<L, Y>::print(const std::string& s,
|
||||||
const std::string& s,
|
const FormatterFunc& formatter) const {
|
||||||
const std::function<std::string(L)>& formatter) const {
|
|
||||||
root_->print(s, formatter);
|
root_->print(s, formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,9 @@ namespace gtsam {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
using FormatterFunc = std::function<std::string(L)>;
|
||||||
|
using ComparatorFunc = std::function<bool(const Y&, const Y&, double)>;
|
||||||
|
|
||||||
/** Handy typedefs for unary and binary function types */
|
/** Handy typedefs for unary and binary function types */
|
||||||
typedef std::function<Y(const Y&)> Unary;
|
typedef std::function<Y(const Y&)> Unary;
|
||||||
typedef std::function<Y(const Y&, const Y&)> Binary;
|
typedef std::function<Y(const Y&, const Y&)> Binary;
|
||||||
|
|
@ -91,15 +94,15 @@ namespace gtsam {
|
||||||
const void* id() const { return this; }
|
const void* id() const { return this; }
|
||||||
|
|
||||||
// everything else is virtual, no documentation here as internal
|
// everything else is virtual, no documentation here as internal
|
||||||
virtual void print(const std::string& s = "",
|
virtual void print(
|
||||||
const std::function<std::string(L)>& formatter =
|
const std::string& s = "",
|
||||||
&DefaultFormatter) const = 0;
|
const FormatterFunc& formatter = &DefaultFormatter) const = 0;
|
||||||
virtual void dot(std::ostream& os, bool showZero) const = 0;
|
virtual void dot(std::ostream& os, bool showZero) const = 0;
|
||||||
virtual bool sameLeaf(const Leaf& q) const = 0;
|
virtual bool sameLeaf(const Leaf& q) const = 0;
|
||||||
virtual bool sameLeaf(const Node& q) const = 0;
|
virtual bool sameLeaf(const Node& q) const = 0;
|
||||||
virtual bool equals(const Node& other, double tol = 1e-9,
|
virtual bool equals(
|
||||||
const std::function<bool(const Y&, const Y&, double)>&
|
const Node& other, double tol = 1e-9,
|
||||||
comparator = &DefaultComparator) const = 0;
|
const ComparatorFunc& comparator = &DefaultComparator) const = 0;
|
||||||
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
||||||
virtual Ptr apply(const Unary& op) const = 0;
|
virtual Ptr apply(const Unary& op) const = 0;
|
||||||
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
||||||
|
|
@ -181,13 +184,11 @@ namespace gtsam {
|
||||||
|
|
||||||
/** GTSAM-style print */
|
/** GTSAM-style print */
|
||||||
void print(const std::string& s = "DecisionTree",
|
void print(const std::string& s = "DecisionTree",
|
||||||
const std::function<std::string(L)>& formatter =
|
const FormatterFunc& formatter = &DefaultFormatter) const;
|
||||||
&DefaultFormatter) const;
|
|
||||||
|
|
||||||
// Testable
|
// Testable
|
||||||
bool equals(const DecisionTree& other, double tol = 1e-9,
|
bool equals(const DecisionTree& other, double tol = 1e-9,
|
||||||
const std::function<bool(const Y&, const Y&, double)>&
|
const ComparatorFunc& comparator = &DefaultComparator) const;
|
||||||
comparator = &DefaultComparator) const;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue