Merge pull request #1000 from borglab/feature/decison_tree
commit
15742270de
|
@ -28,11 +28,22 @@ namespace gtsam {
|
|||
* TODO: consider eliminating this class altogether?
|
||||
*/
|
||||
template<typename L>
|
||||
class AlgebraicDecisionTree: public DecisionTree<L, double> {
|
||||
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
|
||||
/**
|
||||
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
|
||||
*
|
||||
* @param x The value passed to format.
|
||||
* @return std::string
|
||||
*/
|
||||
static std::string DefaultFormatter(const L& x) {
|
||||
std::stringstream ss;
|
||||
ss << x;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
typedef DecisionTree<L, double> Super;
|
||||
using Base = DecisionTree<L, double>;
|
||||
|
||||
/** The Real ring with addition and multiplication */
|
||||
struct Ring {
|
||||
|
@ -60,33 +71,33 @@ namespace gtsam {
|
|||
};
|
||||
|
||||
AlgebraicDecisionTree() :
|
||||
Super(1.0) {
|
||||
Base(1.0) {
|
||||
}
|
||||
|
||||
AlgebraicDecisionTree(const Super& add) :
|
||||
Super(add) {
|
||||
AlgebraicDecisionTree(const Base& add) :
|
||||
Base(add) {
|
||||
}
|
||||
|
||||
/** Create a new leaf function splitting on a variable */
|
||||
AlgebraicDecisionTree(const L& label, double y1, double y2) :
|
||||
Super(label, y1, y2) {
|
||||
Base(label, y1, y2) {
|
||||
}
|
||||
|
||||
/** Create a new leaf function splitting on a variable */
|
||||
AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) :
|
||||
Super(labelC, y1, y2) {
|
||||
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
|
||||
Base(labelC, y1, y2) {
|
||||
}
|
||||
|
||||
/** Create from keys and vector table */
|
||||
AlgebraicDecisionTree //
|
||||
(const std::vector<typename Super::LabelC>& labelCs, const std::vector<double>& ys) {
|
||||
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
||||
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
|
||||
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
||||
ys.end());
|
||||
}
|
||||
|
||||
/** Create from keys and string table */
|
||||
AlgebraicDecisionTree //
|
||||
(const std::vector<typename Super::LabelC>& labelCs, const std::string& table) {
|
||||
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
|
||||
// Convert string to doubles
|
||||
std::vector<double> ys;
|
||||
std::istringstream iss(table);
|
||||
|
@ -94,23 +105,32 @@ namespace gtsam {
|
|||
std::istream_iterator<double>(), std::back_inserter(ys));
|
||||
|
||||
// now call recursive Create
|
||||
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
||||
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
||||
ys.end());
|
||||
}
|
||||
|
||||
/** Create a new function splitting on a variable */
|
||||
template<typename Iterator>
|
||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
|
||||
Super(nullptr) {
|
||||
Base(nullptr) {
|
||||
this->root_ = compose(begin, end, label);
|
||||
}
|
||||
|
||||
/** Convert */
|
||||
/**
|
||||
* Convert labels from type M to type L.
|
||||
*
|
||||
* @param other: The AlgebraicDecisionTree with label type M to convert.
|
||||
* @param map: Map from label type M to label type L.
|
||||
*/
|
||||
template<typename M>
|
||||
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
||||
const std::map<M, L>& map) {
|
||||
this->root_ = this->template convert<M, double>(other.root_, map,
|
||||
Ring::id);
|
||||
// Functor for label conversion so we can use `convertFrom`.
|
||||
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
|
||||
return map.at(label);
|
||||
};
|
||||
std::function<double(const double&)> op = Ring::id;
|
||||
this->root_ = this->template convertFrom(other.root_, L_of_M, op);
|
||||
}
|
||||
|
||||
/** sum */
|
||||
|
@ -134,10 +154,28 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/** sum out variable */
|
||||
AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const {
|
||||
AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const {
|
||||
return this->combine(labelC, &Ring::add);
|
||||
}
|
||||
|
||||
/// print method customized to value type `double`.
|
||||
void print(const std::string& s,
|
||||
const typename Base::LabelFormatter& labelFormatter =
|
||||
&DefaultFormatter) const {
|
||||
auto valueFormatter = [](const double& v) {
|
||||
return (boost::format("%4.2g") % v).str();
|
||||
};
|
||||
Base::print(s, labelFormatter, valueFormatter);
|
||||
}
|
||||
|
||||
/// Equality method customized to value type `double`.
|
||||
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
|
||||
// lambda for comparison of two doubles upto some tolerance.
|
||||
auto compare = [tol](double a, double b) {
|
||||
return std::abs(a - b) < tol;
|
||||
};
|
||||
return Base::equals(other, compare);
|
||||
}
|
||||
};
|
||||
// AlgebraicDecisionTree
|
||||
|
||||
|
|
|
@ -20,21 +20,21 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
#include <boost/format.hpp>
|
||||
#include <boost/noncopyable.hpp>
|
||||
#include <boost/optional.hpp>
|
||||
#include <boost/tuple/tuple.hpp>
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
using boost::assign::operator+=;
|
||||
#include <boost/type_traits/has_dereference.hpp>
|
||||
#include <boost/unordered_set.hpp>
|
||||
#include <boost/noncopyable.hpp>
|
||||
|
||||
#include <list>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <list>
|
||||
#include <sstream>
|
||||
|
||||
using boost::assign::operator+=;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/*********************************************************************************/
|
||||
|
@ -76,22 +76,31 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/** equality up to tolerance */
|
||||
bool equals(const Node& q, double tol) const override {
|
||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
||||
if (!other) return false;
|
||||
return std::abs(double(this->constant_ - other->constant_)) < tol;
|
||||
return compare(this->constant_, other->constant_);
|
||||
}
|
||||
|
||||
/** print */
|
||||
void print(const std::string& s) const override {
|
||||
bool showZero = true;
|
||||
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
|
||||
/**
|
||||
* @brief Print method.
|
||||
*
|
||||
* @param s Prefix string.
|
||||
* @param labelFormatter Functor to format the labels of type L.
|
||||
* @param valueFormatter Functor to format the values of type Y.
|
||||
*/
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const override {
|
||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
||||
}
|
||||
|
||||
/** to graphviz file */
|
||||
void dot(std::ostream& os, bool showZero) const override {
|
||||
if (showZero || constant_) os << "\"" << this->id() << "\" [label=\""
|
||||
<< boost::format("%4.2g") % constant_
|
||||
/** Write graphviz format to stream `os`. */
|
||||
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const override {
|
||||
std::string value = valueFormatter(constant_);
|
||||
if (showZero || value.compare("0"))
|
||||
os << "\"" << this->id() << "\" [label=\"" << value
|
||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
|
||||
}
|
||||
|
||||
|
@ -151,7 +160,7 @@ namespace gtsam {
|
|||
/** incremental allSame */
|
||||
size_t allSame_;
|
||||
|
||||
typedef boost::shared_ptr<const Choice> ChoicePtr;
|
||||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||
|
||||
public:
|
||||
|
||||
|
@ -236,16 +245,19 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/** print (as a tree) */
|
||||
void print(const std::string& s) const override {
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const override {
|
||||
std::cout << s << " Choice(";
|
||||
// std::cout << this << ",";
|
||||
std::cout << label_ << ") " << std::endl;
|
||||
std::cout << labelFormatter(label_) << ") " << std::endl;
|
||||
for (size_t i = 0; i < branches_.size(); i++)
|
||||
branches_[i]->print((boost::format("%s %d") % s % i).str());
|
||||
branches_[i]->print((boost::format("%s %d") % s % i).str(),
|
||||
labelFormatter, valueFormatter);
|
||||
}
|
||||
|
||||
/** output to graphviz (as a a graph) */
|
||||
void dot(std::ostream& os, bool showZero) const override {
|
||||
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const override {
|
||||
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
|
||||
<< "\"]\n";
|
||||
size_t B = branches_.size();
|
||||
|
@ -255,7 +267,8 @@ namespace gtsam {
|
|||
// Check if zero
|
||||
if (!showZero) {
|
||||
const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get());
|
||||
if (leaf && !leaf->constant()) continue;
|
||||
std::string value = valueFormatter(leaf->constant());
|
||||
if (leaf && value.compare("0")) continue;
|
||||
}
|
||||
|
||||
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
||||
|
@ -264,7 +277,7 @@ namespace gtsam {
|
|||
if (i > 1) os << " [style=bold]";
|
||||
}
|
||||
os << std::endl;
|
||||
branch->dot(os, showZero);
|
||||
branch->dot(os, labelFormatter, valueFormatter, showZero);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -278,15 +291,16 @@ namespace gtsam {
|
|||
return (q.isLeaf() && q.sameLeaf(*this));
|
||||
}
|
||||
|
||||
/** equality up to tolerance */
|
||||
bool equals(const Node& q, double tol) const override {
|
||||
/** equality */
|
||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||
const Choice* other = dynamic_cast<const Choice*>(&q);
|
||||
if (!other) return false;
|
||||
if (this->label_ != other->label_) return false;
|
||||
if (branches_.size() != other->branches_.size()) return false;
|
||||
// we don't care about shared pointers being equal here
|
||||
for (size_t i = 0; i < branches_.size(); i++)
|
||||
if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false;
|
||||
if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -449,12 +463,26 @@ namespace gtsam {
|
|||
root_ = compose(functions.begin(), functions.end(), label);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename X>
|
||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||
std::function<Y(const X&)> Y_of_X) {
|
||||
// Define functor for identity mapping of node label.
|
||||
auto L_of_L = [](const L& label) { return label; };
|
||||
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename M, typename X>
|
||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
||||
const std::map<M, L>& map, std::function<Y(const X&)> op) {
|
||||
root_ = convert(other.root_, map, op);
|
||||
const std::map<M, L>& map,
|
||||
std::function<Y(const X&)> Y_of_X) {
|
||||
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
|
||||
return map.at(label);
|
||||
};
|
||||
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
||||
}
|
||||
|
||||
/*********************************************************************************/
|
||||
|
@ -569,34 +597,34 @@ namespace gtsam {
|
|||
/*********************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename M, typename X>
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert(
|
||||
const typename DecisionTree<M, X>::NodePtr& f, const std::map<M, L>& map,
|
||||
std::function<Y(const X&)> op) {
|
||||
|
||||
typedef DecisionTree<M, X> MX;
|
||||
typedef typename MX::Leaf MXLeaf;
|
||||
typedef typename MX::Choice MXChoice;
|
||||
typedef typename MX::NodePtr MXNodePtr;
|
||||
typedef DecisionTree<L, Y> LY;
|
||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||
const typename DecisionTree<M, X>::NodePtr& f,
|
||||
std::function<L(const M&)> L_of_M,
|
||||
std::function<Y(const X&)> Y_of_X) const {
|
||||
using MX = DecisionTree<M, X>;
|
||||
using MXLeaf = typename MX::Leaf;
|
||||
using MXChoice = typename MX::Choice;
|
||||
using MXNodePtr = typename MX::NodePtr;
|
||||
using LY = DecisionTree<L, Y>;
|
||||
|
||||
// ugliness below because apparently we can't have templated virtual functions
|
||||
// If leaf, apply unary conversion "op" and create a unique leaf
|
||||
const MXLeaf* leaf = dynamic_cast<const MXLeaf*> (f.get());
|
||||
if (leaf) return NodePtr(new Leaf(op(leaf->constant())));
|
||||
auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
|
||||
if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||
|
||||
// Check if Choice
|
||||
boost::shared_ptr<const MXChoice> choice = boost::dynamic_pointer_cast<const MXChoice> (f);
|
||||
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
||||
if (!choice) throw std::invalid_argument(
|
||||
"DecisionTree::Convert: Invalid NodePtr");
|
||||
|
||||
// get new label
|
||||
M oldLabel = choice->label();
|
||||
L newLabel = map.at(oldLabel);
|
||||
const M oldLabel = choice->label();
|
||||
const L newLabel = L_of_M(oldLabel);
|
||||
|
||||
// put together via Shannon expansion otherwise not sorted.
|
||||
std::vector<LY> functions;
|
||||
for(const MXNodePtr& branch: choice->branches()) {
|
||||
LY converted(convert<M, X>(branch, map, op));
|
||||
LY converted(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||
functions += converted;
|
||||
}
|
||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||
|
@ -604,13 +632,16 @@ namespace gtsam {
|
|||
|
||||
/*********************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other, double tol) const {
|
||||
return root_->equals(*other.root_, tol);
|
||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||
const CompareFunc& compare) const {
|
||||
return root_->equals(*other.root_, compare);
|
||||
}
|
||||
|
||||
template <typename L, typename Y>
|
||||
void DecisionTree<L, Y>::print(const std::string& s) const {
|
||||
root_->print(s);
|
||||
void DecisionTree<L, Y>::print(const std::string& s,
|
||||
const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const {
|
||||
root_->print(s, labelFormatter, valueFormatter);
|
||||
}
|
||||
|
||||
template<typename L, typename Y>
|
||||
|
@ -625,6 +656,11 @@ namespace gtsam {
|
|||
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
||||
// It is unclear what should happen if tree is empty:
|
||||
if (empty()) {
|
||||
throw std::runtime_error(
|
||||
"DecisionTree::apply(unary op) undefined for empty tree.");
|
||||
}
|
||||
return DecisionTree(root_->apply(op));
|
||||
}
|
||||
|
||||
|
@ -632,6 +668,11 @@ namespace gtsam {
|
|||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
||||
const Binary& op) const {
|
||||
// It is unclear what should happen if either tree is empty:
|
||||
if (empty() || g.empty()) {
|
||||
throw std::runtime_error(
|
||||
"DecisionTree::apply(binary op) undefined for empty trees.");
|
||||
}
|
||||
// apply the operaton on the root of both diagrams
|
||||
NodePtr h = root_->apply_f_op_g(*g.root_, op);
|
||||
// create a new class with the resulting root "h"
|
||||
|
@ -661,25 +702,33 @@ namespace gtsam {
|
|||
|
||||
/*********************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
void DecisionTree<L, Y>::dot(std::ostream& os, bool showZero) const {
|
||||
void DecisionTree<L, Y>::dot(std::ostream& os,
|
||||
const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const {
|
||||
os << "digraph G {\n";
|
||||
root_->dot(os, showZero);
|
||||
root_->dot(os, labelFormatter, valueFormatter, showZero);
|
||||
os << " [ordering=out]}" << std::endl;
|
||||
}
|
||||
|
||||
template <typename L, typename Y>
|
||||
void DecisionTree<L, Y>::dot(const std::string& name, bool showZero) const {
|
||||
void DecisionTree<L, Y>::dot(const std::string& name,
|
||||
const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const {
|
||||
std::ofstream os((name + ".dot").c_str());
|
||||
dot(os, showZero);
|
||||
dot(os, labelFormatter, valueFormatter, showZero);
|
||||
int result = system(
|
||||
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
|
||||
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
|
||||
}
|
||||
|
||||
template <typename L, typename Y>
|
||||
std::string DecisionTree<L, Y>::dot(bool showZero) const {
|
||||
std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const {
|
||||
std::stringstream ss;
|
||||
dot(ss, showZero);
|
||||
dot(ss, labelFormatter, valueFormatter, showZero);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -20,13 +20,13 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/base/types.h>
|
||||
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
|
||||
#include <boost/function.hpp>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -39,14 +39,24 @@ namespace gtsam {
|
|||
template<typename L, typename Y>
|
||||
class GTSAM_EXPORT DecisionTree {
|
||||
|
||||
protected:
|
||||
/// Default method for comparison of two objects of type Y.
|
||||
static bool DefaultCompare(const Y& a, const Y& b) {
|
||||
return a == b;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
using LabelFormatter = std::function<std::string(L)>;
|
||||
using ValueFormatter = std::function<std::string(Y)>;
|
||||
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
||||
|
||||
/** Handy typedefs for unary and binary function types */
|
||||
typedef std::function<Y(const Y&)> Unary;
|
||||
typedef std::function<Y(const Y&, const Y&)> Binary;
|
||||
using Unary = std::function<Y(const Y&)>;
|
||||
using Binary = std::function<Y(const Y&, const Y&)>;
|
||||
|
||||
/** A label annotated with cardinality */
|
||||
typedef std::pair<L,size_t> LabelC;
|
||||
using LabelC = std::pair<L,size_t>;
|
||||
|
||||
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
||||
class Leaf;
|
||||
|
@ -55,7 +65,7 @@ namespace gtsam {
|
|||
/** ------------------------ Node base class --------------------------- */
|
||||
class Node {
|
||||
public:
|
||||
typedef boost::shared_ptr<const Node> Ptr;
|
||||
using Ptr = boost::shared_ptr<const Node>;
|
||||
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
static int nrNodes;
|
||||
|
@ -79,11 +89,16 @@ namespace gtsam {
|
|||
const void* id() const { return this; }
|
||||
|
||||
// everything else is virtual, no documentation here as internal
|
||||
virtual void print(const std::string& s = "") const = 0;
|
||||
virtual void dot(std::ostream& os, bool showZero) const = 0;
|
||||
virtual void print(const std::string& s,
|
||||
const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const = 0;
|
||||
virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero) const = 0;
|
||||
virtual bool sameLeaf(const Leaf& q) const = 0;
|
||||
virtual bool sameLeaf(const Node& q) const = 0;
|
||||
virtual bool equals(const Node& other, double tol = 1e-9) const = 0;
|
||||
virtual bool equals(const Node& other, const CompareFunc& compare =
|
||||
&DefaultCompare) const = 0;
|
||||
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
||||
virtual Ptr apply(const Unary& op) const = 0;
|
||||
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
||||
|
@ -97,9 +112,9 @@ namespace gtsam {
|
|||
public:
|
||||
|
||||
/** A function is a shared pointer to the root of a DT */
|
||||
typedef typename Node::Ptr NodePtr;
|
||||
using NodePtr = typename Node::Ptr;
|
||||
|
||||
/* a DecisionTree just contains the root */
|
||||
/// A DecisionTree just contains the root. TODO(dellaert): make protected.
|
||||
NodePtr root_;
|
||||
|
||||
protected:
|
||||
|
@ -108,19 +123,29 @@ namespace gtsam {
|
|||
template<typename It, typename ValueIt>
|
||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||
|
||||
/** Convert to a different type */
|
||||
template<typename M, typename X> NodePtr
|
||||
convert(const typename DecisionTree<M, X>::NodePtr& f, const std::map<M,
|
||||
L>& map, std::function<Y(const X&)> op);
|
||||
|
||||
/** Default constructor */
|
||||
DecisionTree();
|
||||
/**
|
||||
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
|
||||
*
|
||||
* @tparam M The previous label type.
|
||||
* @tparam X The previous value type.
|
||||
* @param f The node pointer to the root of the previous DecisionTree.
|
||||
* @param L_of_M Functor to convert from label type M to type L.
|
||||
* @param Y_of_X Functor to convert from value type X to type Y.
|
||||
* @return NodePtr
|
||||
*/
|
||||
template <typename M, typename X>
|
||||
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
|
||||
std::function<L(const M&)> L_of_M,
|
||||
std::function<Y(const X&)> Y_of_X) const;
|
||||
|
||||
public:
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Default constructor (for serialization) */
|
||||
DecisionTree();
|
||||
|
||||
/** Create a constant */
|
||||
DecisionTree(const Y& y);
|
||||
|
||||
|
@ -144,20 +169,48 @@ namespace gtsam {
|
|||
DecisionTree(const L& label, //
|
||||
const DecisionTree& f0, const DecisionTree& f1);
|
||||
|
||||
/** Convert from a different type */
|
||||
/**
|
||||
* @brief Convert from a different value type.
|
||||
*
|
||||
* @tparam X The previous value type.
|
||||
* @param other The DecisionTree to convert from.
|
||||
* @param Y_of_X Functor to convert from value type X to type Y.
|
||||
*/
|
||||
template <typename X>
|
||||
DecisionTree(const DecisionTree<L, X>& other,
|
||||
std::function<Y(const X&)> Y_of_X);
|
||||
|
||||
/**
|
||||
* @brief Convert from a different value type X to value type Y, also transate
|
||||
* labels via map from type M to L.
|
||||
*
|
||||
* @tparam M Previous label type.
|
||||
* @tparam X Previous value type.
|
||||
* @param other The decision tree to convert.
|
||||
* @param L_of_M Map from label type M to type L.
|
||||
* @param Y_of_X Functor to convert from type X to type Y.
|
||||
*/
|
||||
template <typename M, typename X>
|
||||
DecisionTree(const DecisionTree<M, X>& other,
|
||||
const std::map<M, L>& map, std::function<Y(const X&)> op);
|
||||
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
|
||||
std::function<Y(const X&)> Y_of_X);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** GTSAM-style print */
|
||||
void print(const std::string& s = "DecisionTree") const;
|
||||
/**
|
||||
* @brief GTSAM-style print
|
||||
*
|
||||
* @param s Prefix string.
|
||||
* @param labelFormatter Functor to format the node label.
|
||||
* @param valueFormatter Functor to format the node value.
|
||||
*/
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const;
|
||||
|
||||
// Testable
|
||||
bool equals(const DecisionTree& other, double tol = 1e-9) const;
|
||||
bool equals(const DecisionTree& other,
|
||||
const CompareFunc& compare = &DefaultCompare) const;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
|
@ -167,6 +220,9 @@ namespace gtsam {
|
|||
virtual ~DecisionTree() {
|
||||
}
|
||||
|
||||
/// Check if tree is empty.
|
||||
bool empty() const { return !root_; }
|
||||
|
||||
/** equality */
|
||||
bool operator==(const DecisionTree& q) const;
|
||||
|
||||
|
@ -195,13 +251,17 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/** output to graphviz format, stream version */
|
||||
void dot(std::ostream& os, bool showZero = true) const;
|
||||
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter, bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format, open a file */
|
||||
void dot(const std::string& name, bool showZero = true) const;
|
||||
void dot(const std::string& name, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter, bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format string */
|
||||
std::string dot(bool showZero = true) const;
|
||||
std::string dot(const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
@ -219,13 +279,15 @@ namespace gtsam {
|
|||
|
||||
/** free versions of apply */
|
||||
|
||||
template<typename Y, typename L>
|
||||
/// Apply unary operator `op` to DecisionTree `f`.
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
||||
const typename DecisionTree<L, Y>::Unary& op) {
|
||||
return f.apply(op);
|
||||
}
|
||||
|
||||
template<typename Y, typename L>
|
||||
/// Apply binary operator `op` to DecisionTree `f`.
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
||||
const DecisionTree<L, Y>& g,
|
||||
const typename DecisionTree<L, Y>::Binary& op) {
|
||||
|
|
|
@ -153,6 +153,31 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
static std::string valueFormatter(const double& v) {
|
||||
return (boost::format("%4.2g") % v).str();
|
||||
}
|
||||
|
||||
/** output to graphviz format, stream version */
|
||||
void DecisionTreeFactor::dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter,
|
||||
bool showZero) const {
|
||||
Potentials::dot(os, keyFormatter, valueFormatter, showZero);
|
||||
}
|
||||
|
||||
/** output to graphviz format, open a file */
|
||||
void DecisionTreeFactor::dot(const std::string& name,
|
||||
const KeyFormatter& keyFormatter,
|
||||
bool showZero) const {
|
||||
Potentials::dot(name, keyFormatter, valueFormatter, showZero);
|
||||
}
|
||||
|
||||
/** output to graphviz format string */
|
||||
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
|
||||
bool showZero) const {
|
||||
return Potentials::dot(keyFormatter, valueFormatter, showZero);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::string DecisionTreeFactor::markdown(
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
|
|
|
@ -178,6 +178,20 @@ namespace gtsam {
|
|||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
/** output to graphviz format, stream version */
|
||||
void dot(std::ostream& os,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format, open a file */
|
||||
void dot(const std::string& name,
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/** output to graphviz format string */
|
||||
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/// Render as markdown table.
|
||||
std::string markdown(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
|
||||
|
|
|
@ -55,7 +55,7 @@ void Potentials::print(const string& s, const KeyFormatter& formatter) const {
|
|||
for (const std::pair<const Key,size_t>& key : cardinalities_)
|
||||
cout << formatter(key.first) << ":" << key.second << ", ";
|
||||
cout << "}" << endl;
|
||||
ADT::print(" ");
|
||||
ADT::print(" ", formatter);
|
||||
}
|
||||
//
|
||||
// /* ************************************************************************* */
|
||||
|
|
|
@ -46,7 +46,9 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
|||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||
string dot(bool showZero = false) const;
|
||||
string dot(
|
||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||
bool showZero = false) const;
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
|
|
@ -136,8 +136,8 @@ ADT create(const Signature& signature) {
|
|||
ADT p(signature.discreteKeys(), signature.cpt());
|
||||
static size_t count = 0;
|
||||
const DiscreteKey& key = signature.key();
|
||||
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
||||
dot(p, dotfile);
|
||||
string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
||||
dot(p, DOTfile);
|
||||
return p;
|
||||
}
|
||||
|
||||
|
@ -414,13 +414,13 @@ TEST(ADT, equality_noparser)
|
|||
// Check straight equality
|
||||
ADT pA1 = create(A % tableA);
|
||||
ADT pA2 = create(A % tableA);
|
||||
EXPECT(pA1 == pA2); // should be equal
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
|
||||
// Check equality after apply
|
||||
ADT pB = create(B % tableB);
|
||||
ADT pAB1 = apply(pA1, pB, &mul);
|
||||
ADT pAB2 = apply(pB, pA1, &mul);
|
||||
EXPECT(pAB2 == pAB1);
|
||||
EXPECT(pAB2.equals(pAB1));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -431,13 +431,13 @@ TEST(ADT, equality_parser)
|
|||
// Check straight equality
|
||||
ADT pA1 = create(A % "80/20");
|
||||
ADT pA2 = create(A % "80/20");
|
||||
EXPECT(pA1 == pA2); // should be equal
|
||||
EXPECT(pA1.equals(pA2)); // should be equal
|
||||
|
||||
// Check equality after apply
|
||||
ADT pB = create(B % "60/40");
|
||||
ADT pAB1 = apply(pA1, pB, &mul);
|
||||
ADT pAB2 = apply(pB, pA1, &mul);
|
||||
EXPECT(pAB2 == pAB1);
|
||||
EXPECT(pAB2.equals(pAB1));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
|
|
|
@ -40,25 +40,69 @@ void dot(const T&f, const string& filename) {
|
|||
|
||||
#define DOT(x)(dot(x,#x))
|
||||
|
||||
struct Crazy { int a; double b; };
|
||||
typedef DecisionTree<string,Crazy> CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be)
|
||||
struct Crazy {
|
||||
int a;
|
||||
double b;
|
||||
};
|
||||
|
||||
struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
|
||||
/// print to stdout
|
||||
void print(const std::string& s = "") const {
|
||||
auto keyFormatter = [](const std::string& s) { return s; };
|
||||
auto valueFormatter = [](const Crazy& v) {
|
||||
return (boost::format("{%d,%4.2g}") % v.a % v.b).str();
|
||||
};
|
||||
DecisionTree<string, Crazy>::print("", keyFormatter, valueFormatter);
|
||||
}
|
||||
/// Equality method customized to Crazy node type
|
||||
bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
|
||||
auto compare = [tol](const Crazy& v, const Crazy& w) {
|
||||
return v.a == w.a && std::abs(v.b - w.b) < tol;
|
||||
};
|
||||
return DecisionTree<string, Crazy>::equals(other, compare);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||
}
|
||||
|
||||
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||
|
||||
/* ******************************************************************************** */
|
||||
// Test string labels and int range
|
||||
/* ******************************************************************************** */
|
||||
|
||||
typedef DecisionTree<string, int> DT;
|
||||
struct DT : public DecisionTree<string, int> {
|
||||
using Base = DecisionTree<string, int>;
|
||||
using DecisionTree::DecisionTree;
|
||||
DT() = default;
|
||||
|
||||
DT(const Base& dt) : Base(dt) {}
|
||||
|
||||
/// print to stdout
|
||||
void print(const std::string& s = "") const {
|
||||
auto keyFormatter = [](const std::string& s) { return s; };
|
||||
auto valueFormatter = [](const int& v) {
|
||||
return (boost::format("%d") % v).str();
|
||||
};
|
||||
Base::print("", keyFormatter, valueFormatter);
|
||||
}
|
||||
/// Equality method customized to int node type
|
||||
bool equals(const Base& other, double tol = 1e-9) const {
|
||||
auto compare = [](const int& v, const int& w) { return v == w; };
|
||||
return Base::equals(other, compare);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
namespace gtsam {
|
||||
template<> struct traits<DT> : public Testable<DT> {};
|
||||
}
|
||||
|
||||
GTSAM_CONCEPT_TESTABLE_INST(DT)
|
||||
|
||||
struct Ring {
|
||||
static inline int zero() {
|
||||
return 0;
|
||||
|
@ -66,6 +110,9 @@ struct Ring {
|
|||
static inline int one() {
|
||||
return 1;
|
||||
}
|
||||
static inline int id(const int& a) {
|
||||
return a;
|
||||
}
|
||||
static inline int add(const int& a, const int& b) {
|
||||
return a + b;
|
||||
}
|
||||
|
@ -88,6 +135,9 @@ TEST(DT, example)
|
|||
x10[A] = 1, x10[B] = 0;
|
||||
x11[A] = 1, x11[B] = 1;
|
||||
|
||||
// empty
|
||||
DT empty;
|
||||
|
||||
// A
|
||||
DT a(A, 0, 5);
|
||||
LONGS_EQUAL(0,a(x00))
|
||||
|
@ -106,6 +156,11 @@ TEST(DT, example)
|
|||
LONGS_EQUAL(5,notb(x10))
|
||||
DOT(notb);
|
||||
|
||||
// Check supplying empty trees yields an exception
|
||||
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
|
||||
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
|
||||
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
|
||||
|
||||
// apply, two nodes, in natural order
|
||||
DT anotb = apply(a, notb, &Ring::mul);
|
||||
LONGS_EQUAL(0,anotb(x00))
|
||||
|
@ -175,16 +230,37 @@ TEST(DT, example)
|
|||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
// test Conversion
|
||||
// test Conversion of values
|
||||
std::function<bool(const int&)> bool_of_int = [](const int& y) {
|
||||
return y != 0;
|
||||
};
|
||||
typedef DecisionTree<string, bool> StringBoolTree;
|
||||
|
||||
TEST(DT, ConvertValuesOnly)
|
||||
{
|
||||
// Create labels
|
||||
string A("A"), B("B");
|
||||
|
||||
// apply, two nodes, in natural order
|
||||
DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul);
|
||||
|
||||
// convert
|
||||
StringBoolTree f2(f1, bool_of_int);
|
||||
|
||||
// Check a value
|
||||
Assignment<string> x00;
|
||||
x00["A"] = 0, x00["B"] = 0;
|
||||
EXPECT(!f2(x00));
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
// test Conversion of both values and labels.
|
||||
enum Label {
|
||||
U, V, X, Y, Z
|
||||
};
|
||||
typedef DecisionTree<Label, bool> BDT;
|
||||
bool convert(const int& y) {
|
||||
return y != 0;
|
||||
}
|
||||
typedef DecisionTree<Label, bool> LabelBoolTree;
|
||||
|
||||
TEST(DT, conversion)
|
||||
TEST(DT, ConvertBoth)
|
||||
{
|
||||
// Create labels
|
||||
string A("A"), B("B");
|
||||
|
@ -196,12 +272,9 @@ TEST(DT, conversion)
|
|||
map<string, Label> ordering;
|
||||
ordering[A] = X;
|
||||
ordering[B] = Y;
|
||||
std::function<bool(const int&)> op = convert;
|
||||
BDT f2(f1, ordering, op);
|
||||
// f1.print("f1");
|
||||
// f2.print("f2");
|
||||
LabelBoolTree f2(f1, ordering, bool_of_int);
|
||||
|
||||
// create a value
|
||||
// Check some values
|
||||
Assignment<Label> x00, x01, x10, x11;
|
||||
x00[X] = 0, x00[Y] = 0;
|
||||
x01[X] = 0, x01[Y] = 1;
|
||||
|
|
Loading…
Reference in New Issue