Merge pull request #1002 from borglab/feature/decision_tree_2

release/4.3a0
Frank Dellaert 2022-01-03 10:15:01 -05:00 committed by GitHub
commit 8a28ac2426
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 48 deletions

View File

@ -29,7 +29,12 @@ namespace gtsam {
*/ */
template<typename L> template<typename L>
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> { class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
/// Default method used by `formatter` when printing. /**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
*
* @param x The value passed to format.
* @return std::string
*/
static std::string DefaultFormatter(const L& x) { static std::string DefaultFormatter(const L& x) {
std::stringstream ss; std::stringstream ss;
ss << x; ss << x;
@ -38,7 +43,7 @@ namespace gtsam {
public: public:
typedef DecisionTree<L, double> Super; using Base = DecisionTree<L, double>;
/** The Real ring with addition and multiplication */ /** The Real ring with addition and multiplication */
struct Ring { struct Ring {
@ -66,33 +71,33 @@ namespace gtsam {
}; };
AlgebraicDecisionTree() : AlgebraicDecisionTree() :
Super(1.0) { Base(1.0) {
} }
AlgebraicDecisionTree(const Super& add) : AlgebraicDecisionTree(const Base& add) :
Super(add) { Base(add) {
} }
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) : AlgebraicDecisionTree(const L& label, double y1, double y2) :
Super(label, y1, y2) { Base(label, y1, y2) {
} }
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
Super(labelC, y1, y2) { Base(labelC, y1, y2) {
} }
/** Create from keys and vector table */ /** Create from keys and vector table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::vector<double>& ys) { (const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end()); ys.end());
} }
/** Create from keys and string table */ /** Create from keys and string table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::string& table) { (const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
// Convert string to doubles // Convert string to doubles
std::vector<double> ys; std::vector<double> ys;
std::istringstream iss(table); std::istringstream iss(table);
@ -100,21 +105,27 @@ namespace gtsam {
std::istream_iterator<double>(), std::back_inserter(ys)); std::istream_iterator<double>(), std::back_inserter(ys));
// now call recursive Create // now call recursive Create
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end()); ys.end());
} }
/** Create a new function splitting on a variable */ /** Create a new function splitting on a variable */
template<typename Iterator> template<typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
Super(nullptr) { Base(nullptr) {
this->root_ = compose(begin, end, label); this->root_ = compose(begin, end, label);
} }
/** Convert */ /**
* Convert labels from type M to type L.
*
* @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L.
*/
template<typename M> template<typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other, AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) { const std::map<M, L>& map) {
// Functor for label conversion so we can use `convertFrom`.
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L { std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
return map.at(label); return map.at(label);
}; };
@ -143,18 +154,18 @@ namespace gtsam {
} }
/** sum out variable */ /** sum out variable */
AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const {
return this->combine(labelC, &Ring::add); return this->combine(labelC, &Ring::add);
} }
/// print method customized to node type `double`. /// print method customized to node type `double`.
void print(const std::string& s, void print(const std::string& s,
const typename Super::LabelFormatter& labelFormatter = const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const { &DefaultFormatter) const {
auto valueFormatter = [](const double& v) { auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str(); return (boost::format("%4.2g") % v).str();
}; };
Super::print(s, labelFormatter, valueFormatter); Base::print(s, labelFormatter, valueFormatter);
} }
/// Equality method customized to node type `double`. /// Equality method customized to node type `double`.
@ -163,7 +174,7 @@ namespace gtsam {
auto compare = [tol](double a, double b) { auto compare = [tol](double a, double b) {
return std::abs(a - b) < tol; return std::abs(a - b) < tol;
}; };
return Super::equals(other, compare); return Base::equals(other, compare);
} }
}; };
// AlgebraicDecisionTree // AlgebraicDecisionTree

View File

@ -82,13 +82,19 @@ namespace gtsam {
return compare(this->constant_, other->constant_); return compare(this->constant_, other->constant_);
} }
/** print */ /**
* @brief Print method.
*
* @param s Prefix string.
* @param labelFormatter Functor to format the node label.
* @param valueFormatter Functor to format the node value.
*/
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
} }
/** to graphviz file */ /** Write graphviz format to stream `os`. */
void dot(std::ostream& os, const LabelFormatter& labelFormatter, void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, const ValueFormatter& valueFormatter,
bool showZero) const override { bool showZero) const override {
@ -154,7 +160,7 @@ namespace gtsam {
/** incremental allSame */ /** incremental allSame */
size_t allSame_; size_t allSame_;
typedef boost::shared_ptr<const Choice> ChoicePtr; using ChoicePtr = boost::shared_ptr<const Choice>;
public: public:
@ -462,6 +468,7 @@ namespace gtsam {
template <typename X> template <typename X>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> Y_of_X) { std::function<Y(const X&)> Y_of_X) {
// Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; }; auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(Y_of_X, L_of_L); root_ = convertFrom<L, X>(Y_of_X, L_of_L);
} }
@ -594,11 +601,11 @@ namespace gtsam {
const typename DecisionTree<M, X>::NodePtr& f, const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M, std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const { std::function<Y(const X&)> Y_of_X) const {
typedef DecisionTree<M, X> MX; using MX = DecisionTree<M, X>;
typedef typename MX::Leaf MXLeaf; using MXLeaf = typename MX::Leaf;
typedef typename MX::Choice MXChoice; using MXChoice = typename MX::Choice;
typedef typename MX::NodePtr MXNodePtr; using MXNodePtr = typename MX::NodePtr;
typedef DecisionTree<L, Y> LY; using LY = DecisionTree<L, Y>;
// ugliness below because apparently we can't have templated virtual functions // ugliness below because apparently we can't have templated virtual functions
// If leaf, apply unary conversion "op" and create a unique leaf // If leaf, apply unary conversion "op" and create a unique leaf

View File

@ -39,6 +39,7 @@ namespace gtsam {
template<typename L, typename Y> template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree { class GTSAM_EXPORT DecisionTree {
protected:
/// Default method for comparison of two objects of type Y. /// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) { static bool DefaultCompare(const Y& a, const Y& b) {
return a == b; return a == b;
@ -51,11 +52,11 @@ namespace gtsam {
using CompareFunc = std::function<bool(const Y&, const Y&)>; using CompareFunc = std::function<bool(const Y&, const Y&)>;
/** Handy typedefs for unary and binary function types */ /** Handy typedefs for unary and binary function types */
typedef std::function<Y(const Y&)> Unary; using Unary = std::function<Y(const Y&)>;
typedef std::function<Y(const Y&, const Y&)> Binary; using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */ /** A label annotated with cardinality */
typedef std::pair<L,size_t> LabelC; using LabelC = std::pair<L,size_t>;
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */ /** DTs consist of Leaf and Choice nodes, both subclasses of Node */
class Leaf; class Leaf;
@ -64,7 +65,7 @@ namespace gtsam {
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
class Node { class Node {
public: public:
typedef boost::shared_ptr<const Node> Ptr; using Ptr = boost::shared_ptr<const Node>;
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
static int nrNodes; static int nrNodes;
@ -111,9 +112,9 @@ namespace gtsam {
public: public:
/** A function is a shared pointer to the root of a DT */ /** A function is a shared pointer to the root of a DT */
typedef typename Node::Ptr NodePtr; using NodePtr = typename Node::Ptr;
/// a DecisionTree just contains the root. TODO(dellaert): make protected. /// A DecisionTree just contains the root. TODO(dellaert): make protected.
NodePtr root_; NodePtr root_;
protected: protected:
@ -122,7 +123,16 @@ namespace gtsam {
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
/// Convert from a DecisionTree<M, X>. /**
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
*
* @tparam M The previous label type.
* @tparam X The previous node type.
* @param f The node pointer to the root of the previous DecisionTree.
* @param L_of_M Functor to convert from label type M to type L.
* @param Y_of_X Functor to convert from node type X to type Y.
* @return NodePtr
*/
template <typename M, typename X> template <typename M, typename X>
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f, NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M, std::function<L(const M&)> L_of_M,
@ -159,12 +169,27 @@ namespace gtsam {
DecisionTree(const L& label, // DecisionTree(const L& label, //
const DecisionTree& f0, const DecisionTree& f1); const DecisionTree& f0, const DecisionTree& f1);
/** Convert from a different type. */ /**
* @brief Convert from a different node type.
*
* @tparam X The previous node type.
* @param other The DecisionTree to convert from.
* @param Y_of_X Functor to convert from node type X to type Y.
*/
template <typename X> template <typename X>
DecisionTree(const DecisionTree<L, X>& other, DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> Y_of_X); std::function<Y(const X&)> Y_of_X);
/** Convert from a different type, also transate labels via map. */ /**
* @brief Convert from a different node type X to node type Y, also translate
* labels via map from type M to L.
*
* @tparam M Previous label type.
* @tparam X Previous node type.
* @param other The decision tree to convert.
* @param L_of_M Map from label type M to type L.
* @param Y_of_X Functor to convert from type X to type Y.
*/
template <typename M, typename X> template <typename M, typename X>
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M, DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
std::function<Y(const X&)> Y_of_X); std::function<Y(const X&)> Y_of_X);
@ -173,7 +198,13 @@ namespace gtsam {
/// @name Testable /// @name Testable
/// @{ /// @{
/** GTSAM-style print */ /**
* @brief GTSAM-style print
*
* @param s Prefix string.
* @param labelFormatter Functor to format the node label.
* @param valueFormatter Functor to format the node value.
*/
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const; const ValueFormatter& valueFormatter) const;
@ -189,7 +220,7 @@ namespace gtsam {
virtual ~DecisionTree() { virtual ~DecisionTree() {
} }
/** empty tree? */ /// Check if tree is empty.
bool empty() const { return !root_; } bool empty() const { return !root_; }
/** equality */ /** equality */
@ -248,18 +279,21 @@ namespace gtsam {
/** free versions of apply */ /** free versions of apply */
/// Apply unary operator `op` to DecisionTree `f`.
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f, DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const typename DecisionTree<L, Y>::Unary& op) { const typename DecisionTree<L, Y>::Unary& op) {
return f.apply(op); return f.apply(op);
} }
/// Apply unary operator `op` to DecisionTree `f`, where `op` takes an argument of type X and returns node type Y.
template<typename L, typename Y, typename X> template<typename L, typename Y, typename X>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f, DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const std::function<Y(const X&)>& op) { const std::function<Y(const X&)>& op) {
return f.apply(op); return f.apply(op);
} }
/// Apply binary operator `op` to DecisionTrees `f` and `g`.
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f, DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const DecisionTree<L, Y>& g, const DecisionTree<L, Y>& g,

View File

@ -45,15 +45,6 @@ struct Crazy {
double b; double b;
}; };
// bool equals(const Crazy& other, double tol = 1e-12) const {
// return a == other.a && std::abs(b - other.b) < tol;
// }
// bool operator==(const Crazy& other) const {
// return this->equals(other);
// }
// };
struct CrazyDecisionTree : public DecisionTree<string, Crazy> { struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
/// print to stdout /// print to stdout
void print(const std::string& s = "") const { void print(const std::string& s = "") const {
@ -261,8 +252,6 @@ TEST(DT, conversion)
return y != 0; return y != 0;
}; };
BDT f2(f1, ordering, bool_of_int); BDT f2(f1, ordering, bool_of_int);
// f1.print("f1");
// f2.print("f2");
// create a value // create a value
Assignment<Label> x00, x01, x10, x11; Assignment<Label> x00, x01, x10, x11;