Move DefaultFormatter to base class and add defaults. Also replace Super with Base and add using.

release/4.3a0
Varun Agrawal 2022-01-02 23:45:01 -05:00
parent 636425404d
commit a9b2c32669
2 changed files with 75 additions and 42 deletions

View File

@ -29,16 +29,9 @@ namespace gtsam {
*/
template<typename L>
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
/// Default method used by `formatter` when printing.
static std::string DefaultFormatter(const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}
public:
typedef DecisionTree<L, double> Super;
using Base = DecisionTree<L, double>;
/** The Real ring with addition and multiplication */
struct Ring {
@ -66,33 +59,33 @@ namespace gtsam {
};
AlgebraicDecisionTree() :
Super(1.0) {
Base(1.0) {
}
AlgebraicDecisionTree(const Super& add) :
Super(add) {
AlgebraicDecisionTree(const Base& add) :
Base(add) {
}
/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) :
Super(label, y1, y2) {
Base(label, y1, y2) {
}
/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) :
Super(labelC, y1, y2) {
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
Base(labelC, y1, y2) {
}
/** Create from keys and vector table */
AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
}
/** Create from keys and string table */
AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::string& table) {
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
// Convert string to doubles
std::vector<double> ys;
std::istringstream iss(table);
@ -100,18 +93,23 @@ namespace gtsam {
std::istream_iterator<double>(), std::back_inserter(ys));
// now call recursive Create
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
}
/** Create a new function splitting on a variable */
template<typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
Super(nullptr) {
Base(nullptr) {
this->root_ = compose(begin, end, label);
}
/** Convert */
/**
* Convert labels from type M to type L.
*
* @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L.
*/
template<typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) {
@ -143,18 +141,18 @@ namespace gtsam {
}
/** sum out variable */
AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const {
AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const {
return this->combine(labelC, &Ring::add);
}
/// print method customized to node type `double`.
void print(const std::string& s,
const typename Super::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
const typename Base::LabelFormatter& labelFormatter =
&Base::DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str();
};
Super::print(s, labelFormatter, valueFormatter);
Base::print(s, labelFormatter, valueFormatter);
}
/// Equality method customized to node type `double`.
@ -163,7 +161,7 @@ namespace gtsam {
auto compare = [tol](double a, double b) {
return std::abs(a - b) < tol;
};
return Super::equals(other, compare);
return Base::equals(other, compare);
}
};
// AlgebraicDecisionTree

View File

@ -39,11 +39,24 @@ namespace gtsam {
template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree {
protected:
/// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) {
return a == b;
}
/**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
*
* @param x The value passed to format.
* @return std::string
*/
static std::string DefaultFormatter(const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}
public:
using LabelFormatter = std::function<std::string(L)>;
@ -88,12 +101,14 @@ namespace gtsam {
const void* id() const { return this; }
// everything else is virtual, no documentation here as internal
virtual void print(const std::string& s,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const = 0;
virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const = 0;
virtual void print(
const std::string& s,
const LabelFormatter& labelFormatter = &DefaultFormatter,
const ValueFormatter& valueFormatter = &DefaultFormatter) const = 0;
virtual void dot(std::ostream& os,
const LabelFormatter& labelFormatter = &DefaultFormatter,
const ValueFormatter& valueFormatter = &DefaultFormatter,
bool showZero = true) const = 0;
virtual bool sameLeaf(const Leaf& q) const = 0;
virtual bool sameLeaf(const Node& q) const = 0;
virtual bool equals(const Node& other, const CompareFunc& compare =
@ -111,7 +126,7 @@ namespace gtsam {
public:
/** A function is a shared pointer to the root of a DT */
typedef typename Node::Ptr NodePtr;
using NodePtr = typename Node::Ptr;
/// a DecisionTree just contains the root. TODO(dellaert): make protected.
NodePtr root_;
@ -164,7 +179,16 @@ namespace gtsam {
DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> Y_of_X);
/** Convert from a different type, also transate labels via map. */
/**
* @brief Convert from a different node type X to node type Y, also transate
* labels via map from type M to L.
*
* @tparam M Previous label type.
* @tparam X Previous node type.
* @param other The decision tree to convert.
* @param L_of_M Map from label type M to type L.
* @param Y_of_X Functor to convert from type X to type Y.
*/
template <typename M, typename X>
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
std::function<Y(const X&)> Y_of_X);
@ -173,9 +197,16 @@ namespace gtsam {
/// @name Testable
/// @{
/** GTSAM-style print */
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const;
/**
* @brief GTSAM-style print
*
* @param s Prefix string.
* @param labelFormatter Functor to format the node label.
* @param valueFormatter Functor to format the node value.
*/
void print(const std::string& s,
const LabelFormatter& labelFormatter = &DefaultFormatter,
const ValueFormatter& valueFormatter = &DefaultFormatter) const;
// Testable
bool equals(const DecisionTree& other,
@ -220,16 +251,20 @@ namespace gtsam {
}
/** output to graphviz format, stream version */
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, bool showZero = true) const;
void dot(std::ostream& os,
const LabelFormatter& labelFormatter = &DefaultFormatter,
const ValueFormatter& valueFormatter = &DefaultFormatter,
bool showZero = true) const;
/** output to graphviz format, open a file */
void dot(const std::string& name, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, bool showZero = true) const;
void dot(const std::string& name,
const LabelFormatter& labelFormatter = &DefaultFormatter,
const ValueFormatter& valueFormatter = &DefaultFormatter,
bool showZero = true) const;
/** output to graphviz format string */
std::string dot(const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
std::string dot(const LabelFormatter& labelFormatter = &DefaultFormatter,
const ValueFormatter& valueFormatter = &DefaultFormatter,
bool showZero = true) const;
/// @name Advanced Interface