Move DefaultFormatter to base class and add defaults. Also replace Super with Base and add using.
parent
636425404d
commit
a9b2c32669
|
|
@ -29,16 +29,9 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
template<typename L>
|
template<typename L>
|
||||||
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
|
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
|
||||||
/// Default method used by `formatter` when printing.
|
|
||||||
static std::string DefaultFormatter(const L& x) {
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << x;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef DecisionTree<L, double> Super;
|
using Base = DecisionTree<L, double>;
|
||||||
|
|
||||||
/** The Real ring with addition and multiplication */
|
/** The Real ring with addition and multiplication */
|
||||||
struct Ring {
|
struct Ring {
|
||||||
|
|
@ -66,33 +59,33 @@ namespace gtsam {
|
||||||
};
|
};
|
||||||
|
|
||||||
AlgebraicDecisionTree() :
|
AlgebraicDecisionTree() :
|
||||||
Super(1.0) {
|
Base(1.0) {
|
||||||
}
|
}
|
||||||
|
|
||||||
AlgebraicDecisionTree(const Super& add) :
|
AlgebraicDecisionTree(const Base& add) :
|
||||||
Super(add) {
|
Base(add) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
AlgebraicDecisionTree(const L& label, double y1, double y2) :
|
AlgebraicDecisionTree(const L& label, double y1, double y2) :
|
||||||
Super(label, y1, y2) {
|
Base(label, y1, y2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) :
|
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
|
||||||
Super(labelC, y1, y2) {
|
Base(labelC, y1, y2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create from keys and vector table */
|
/** Create from keys and vector table */
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Super::LabelC>& labelCs, const std::vector<double>& ys) {
|
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
|
||||||
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
||||||
ys.end());
|
ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create from keys and string table */
|
/** Create from keys and string table */
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Super::LabelC>& labelCs, const std::string& table) {
|
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
|
||||||
// Convert string to doubles
|
// Convert string to doubles
|
||||||
std::vector<double> ys;
|
std::vector<double> ys;
|
||||||
std::istringstream iss(table);
|
std::istringstream iss(table);
|
||||||
|
|
@ -100,18 +93,23 @@ namespace gtsam {
|
||||||
std::istream_iterator<double>(), std::back_inserter(ys));
|
std::istream_iterator<double>(), std::back_inserter(ys));
|
||||||
|
|
||||||
// now call recursive Create
|
// now call recursive Create
|
||||||
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
||||||
ys.end());
|
ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new function splitting on a variable */
|
/** Create a new function splitting on a variable */
|
||||||
template<typename Iterator>
|
template<typename Iterator>
|
||||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
|
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
|
||||||
Super(nullptr) {
|
Base(nullptr) {
|
||||||
this->root_ = compose(begin, end, label);
|
this->root_ = compose(begin, end, label);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Convert */
|
/**
|
||||||
|
* Convert labels from type M to type L.
|
||||||
|
*
|
||||||
|
* @param other: The AlgebraicDecisionTree with label type M to convert.
|
||||||
|
* @param map: Map from label type M to label type L.
|
||||||
|
*/
|
||||||
template<typename M>
|
template<typename M>
|
||||||
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
||||||
const std::map<M, L>& map) {
|
const std::map<M, L>& map) {
|
||||||
|
|
@ -143,18 +141,18 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** sum out variable */
|
/** sum out variable */
|
||||||
AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const {
|
AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const {
|
||||||
return this->combine(labelC, &Ring::add);
|
return this->combine(labelC, &Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// print method customized to node type `double`.
|
/// print method customized to node type `double`.
|
||||||
void print(const std::string& s,
|
void print(const std::string& s,
|
||||||
const typename Super::LabelFormatter& labelFormatter =
|
const typename Base::LabelFormatter& labelFormatter =
|
||||||
&DefaultFormatter) const {
|
&Base::DefaultFormatter) const {
|
||||||
auto valueFormatter = [](const double& v) {
|
auto valueFormatter = [](const double& v) {
|
||||||
return (boost::format("%4.2g") % v).str();
|
return (boost::format("%4.2g") % v).str();
|
||||||
};
|
};
|
||||||
Super::print(s, labelFormatter, valueFormatter);
|
Base::print(s, labelFormatter, valueFormatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Equality method customized to node type `double`.
|
/// Equality method customized to node type `double`.
|
||||||
|
|
@ -163,7 +161,7 @@ namespace gtsam {
|
||||||
auto compare = [tol](double a, double b) {
|
auto compare = [tol](double a, double b) {
|
||||||
return std::abs(a - b) < tol;
|
return std::abs(a - b) < tol;
|
||||||
};
|
};
|
||||||
return Super::equals(other, compare);
|
return Base::equals(other, compare);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// AlgebraicDecisionTree
|
// AlgebraicDecisionTree
|
||||||
|
|
|
||||||
|
|
@ -39,11 +39,24 @@ namespace gtsam {
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class GTSAM_EXPORT DecisionTree {
|
class GTSAM_EXPORT DecisionTree {
|
||||||
|
|
||||||
|
protected:
|
||||||
/// Default method for comparison of two objects of type Y.
|
/// Default method for comparison of two objects of type Y.
|
||||||
static bool DefaultCompare(const Y& a, const Y& b) {
|
static bool DefaultCompare(const Y& a, const Y& b) {
|
||||||
return a == b;
|
return a == b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
|
||||||
|
*
|
||||||
|
* @param x The value passed to format.
|
||||||
|
* @return std::string
|
||||||
|
*/
|
||||||
|
static std::string DefaultFormatter(const L& x) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << x;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
using LabelFormatter = std::function<std::string(L)>;
|
using LabelFormatter = std::function<std::string(L)>;
|
||||||
|
|
@ -88,12 +101,14 @@ namespace gtsam {
|
||||||
const void* id() const { return this; }
|
const void* id() const { return this; }
|
||||||
|
|
||||||
// everything else is virtual, no documentation here as internal
|
// everything else is virtual, no documentation here as internal
|
||||||
virtual void print(const std::string& s,
|
virtual void print(
|
||||||
const LabelFormatter& labelFormatter,
|
const std::string& s,
|
||||||
const ValueFormatter& valueFormatter) const = 0;
|
const LabelFormatter& labelFormatter = &DefaultFormatter,
|
||||||
virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
const ValueFormatter& valueFormatter = &DefaultFormatter) const = 0;
|
||||||
const ValueFormatter& valueFormatter,
|
virtual void dot(std::ostream& os,
|
||||||
bool showZero) const = 0;
|
const LabelFormatter& labelFormatter = &DefaultFormatter,
|
||||||
|
const ValueFormatter& valueFormatter = &DefaultFormatter,
|
||||||
|
bool showZero = true) const = 0;
|
||||||
virtual bool sameLeaf(const Leaf& q) const = 0;
|
virtual bool sameLeaf(const Leaf& q) const = 0;
|
||||||
virtual bool sameLeaf(const Node& q) const = 0;
|
virtual bool sameLeaf(const Node& q) const = 0;
|
||||||
virtual bool equals(const Node& other, const CompareFunc& compare =
|
virtual bool equals(const Node& other, const CompareFunc& compare =
|
||||||
|
|
@ -111,7 +126,7 @@ namespace gtsam {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/** A function is a shared pointer to the root of a DT */
|
/** A function is a shared pointer to the root of a DT */
|
||||||
typedef typename Node::Ptr NodePtr;
|
using NodePtr = typename Node::Ptr;
|
||||||
|
|
||||||
/// a DecisionTree just contains the root. TODO(dellaert): make protected.
|
/// a DecisionTree just contains the root. TODO(dellaert): make protected.
|
||||||
NodePtr root_;
|
NodePtr root_;
|
||||||
|
|
@ -164,7 +179,16 @@ namespace gtsam {
|
||||||
DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree(const DecisionTree<L, X>& other,
|
||||||
std::function<Y(const X&)> Y_of_X);
|
std::function<Y(const X&)> Y_of_X);
|
||||||
|
|
||||||
/** Convert from a different type, also transate labels via map. */
|
/**
|
||||||
|
* @brief Convert from a different node type X to node type Y, also transate
|
||||||
|
* labels via map from type M to L.
|
||||||
|
*
|
||||||
|
* @tparam M Previous label type.
|
||||||
|
* @tparam X Previous node type.
|
||||||
|
* @param other The decision tree to convert.
|
||||||
|
* @param L_of_M Map from label type M to type L.
|
||||||
|
* @param Y_of_X Functor to convert from type X to type Y.
|
||||||
|
*/
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
|
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
|
||||||
std::function<Y(const X&)> Y_of_X);
|
std::function<Y(const X&)> Y_of_X);
|
||||||
|
|
@ -173,9 +197,16 @@ namespace gtsam {
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** GTSAM-style print */
|
/**
|
||||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
* @brief GTSAM-style print
|
||||||
const ValueFormatter& valueFormatter) const;
|
*
|
||||||
|
* @param s Prefix string.
|
||||||
|
* @param labelFormatter Functor to format the node label.
|
||||||
|
* @param valueFormatter Functor to format the node value.
|
||||||
|
*/
|
||||||
|
void print(const std::string& s,
|
||||||
|
const LabelFormatter& labelFormatter = &DefaultFormatter,
|
||||||
|
const ValueFormatter& valueFormatter = &DefaultFormatter) const;
|
||||||
|
|
||||||
// Testable
|
// Testable
|
||||||
bool equals(const DecisionTree& other,
|
bool equals(const DecisionTree& other,
|
||||||
|
|
@ -220,16 +251,20 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** output to graphviz format, stream version */
|
/** output to graphviz format, stream version */
|
||||||
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
void dot(std::ostream& os,
|
||||||
const ValueFormatter& valueFormatter, bool showZero = true) const;
|
const LabelFormatter& labelFormatter = &DefaultFormatter,
|
||||||
|
const ValueFormatter& valueFormatter = &DefaultFormatter,
|
||||||
|
bool showZero = true) const;
|
||||||
|
|
||||||
/** output to graphviz format, open a file */
|
/** output to graphviz format, open a file */
|
||||||
void dot(const std::string& name, const LabelFormatter& labelFormatter,
|
void dot(const std::string& name,
|
||||||
const ValueFormatter& valueFormatter, bool showZero = true) const;
|
const LabelFormatter& labelFormatter = &DefaultFormatter,
|
||||||
|
const ValueFormatter& valueFormatter = &DefaultFormatter,
|
||||||
|
bool showZero = true) const;
|
||||||
|
|
||||||
/** output to graphviz format string */
|
/** output to graphviz format string */
|
||||||
std::string dot(const LabelFormatter& labelFormatter,
|
std::string dot(const LabelFormatter& labelFormatter = &DefaultFormatter,
|
||||||
const ValueFormatter& valueFormatter,
|
const ValueFormatter& valueFormatter = &DefaultFormatter,
|
||||||
bool showZero = true) const;
|
bool showZero = true) const;
|
||||||
|
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue