Merge pull request #1000 from borglab/feature/decison_tree

release/4.3a0
Frank Dellaert 2022-01-03 17:57:59 -05:00 committed by GitHub
commit 15742270de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 401 additions and 138 deletions

View File

@ -28,11 +28,22 @@ namespace gtsam {
* TODO: consider eliminating this class altogether? * TODO: consider eliminating this class altogether?
*/ */
template<typename L> template<typename L>
class AlgebraicDecisionTree: public DecisionTree<L, double> { class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
/**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
*
* @param x The value passed to format.
* @return std::string
*/
static std::string DefaultFormatter(const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}
public: public:
typedef DecisionTree<L, double> Super; using Base = DecisionTree<L, double>;
/** The Real ring with addition and multiplication */ /** The Real ring with addition and multiplication */
struct Ring { struct Ring {
@ -60,33 +71,33 @@ namespace gtsam {
}; };
AlgebraicDecisionTree() : AlgebraicDecisionTree() :
Super(1.0) { Base(1.0) {
} }
AlgebraicDecisionTree(const Super& add) : AlgebraicDecisionTree(const Base& add) :
Super(add) { Base(add) {
} }
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) : AlgebraicDecisionTree(const L& label, double y1, double y2) :
Super(label, y1, y2) { Base(label, y1, y2) {
} }
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
Super(labelC, y1, y2) { Base(labelC, y1, y2) {
} }
/** Create from keys and vector table */ /** Create from keys and vector table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::vector<double>& ys) { (const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end()); ys.end());
} }
/** Create from keys and string table */ /** Create from keys and string table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::string& table) { (const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
// Convert string to doubles // Convert string to doubles
std::vector<double> ys; std::vector<double> ys;
std::istringstream iss(table); std::istringstream iss(table);
@ -94,23 +105,32 @@ namespace gtsam {
std::istream_iterator<double>(), std::back_inserter(ys)); std::istream_iterator<double>(), std::back_inserter(ys));
// now call recursive Create // now call recursive Create
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end()); ys.end());
} }
/** Create a new function splitting on a variable */ /** Create a new function splitting on a variable */
template<typename Iterator> template<typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
Super(nullptr) { Base(nullptr) {
this->root_ = compose(begin, end, label); this->root_ = compose(begin, end, label);
} }
/** Convert */ /**
* Convert labels from type M to type L.
*
* @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L.
*/
template<typename M> template<typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other, AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) { const std::map<M, L>& map) {
this->root_ = this->template convert<M, double>(other.root_, map, // Functor for label conversion so we can use `convertFrom`.
Ring::id); std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
return map.at(label);
};
std::function<double(const double&)> op = Ring::id;
this->root_ = this->template convertFrom(other.root_, L_of_M, op);
} }
/** sum */ /** sum */
@ -134,10 +154,28 @@ namespace gtsam {
} }
/** sum out variable */ /** sum out variable */
AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const {
return this->combine(labelC, &Ring::add); return this->combine(labelC, &Ring::add);
} }
/// print method customized to value type `double`.
void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str();
};
Base::print(s, labelFormatter, valueFormatter);
}
/// Equality method customized to value type `double`.
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
// lambda for comparison of two doubles upto some tolerance.
auto compare = [tol](double a, double b) {
return std::abs(a - b) < tol;
};
return Base::equals(other, compare);
}
}; };
// AlgebraicDecisionTree // AlgebraicDecisionTree

View File

@ -20,21 +20,21 @@
#pragma once #pragma once
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/base/Testable.h>
#include <boost/assign/std/vector.hpp>
#include <boost/format.hpp> #include <boost/format.hpp>
#include <boost/noncopyable.hpp>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <boost/assign/std/vector.hpp> #include <boost/type_traits/has_dereference.hpp>
using boost::assign::operator+=;
#include <boost/unordered_set.hpp> #include <boost/unordered_set.hpp>
#include <boost/noncopyable.hpp>
#include <list>
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
#include <list>
#include <sstream> #include <sstream>
using boost::assign::operator+=;
namespace gtsam { namespace gtsam {
/*********************************************************************************/ /*********************************************************************************/
@ -76,23 +76,32 @@ namespace gtsam {
} }
/** equality up to tolerance */ /** equality up to tolerance */
bool equals(const Node& q, double tol) const override { bool equals(const Node& q, const CompareFunc& compare) const override {
const Leaf* other = dynamic_cast<const Leaf*> (&q); const Leaf* other = dynamic_cast<const Leaf*>(&q);
if (!other) return false; if (!other) return false;
return std::abs(double(this->constant_ - other->constant_)) < tol; return compare(this->constant_, other->constant_);
} }
/** print */ /**
void print(const std::string& s) const override { * @brief Print method.
bool showZero = true; *
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; * @param s Prefix string.
* @param labelFormatter Functor to format the labels of type L.
* @param valueFormatter Functor to format the values of type Y.
*/
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
} }
/** to graphviz file */ /** Write graphviz format to stream `os`. */
void dot(std::ostream& os, bool showZero) const override { void dot(std::ostream& os, const LabelFormatter& labelFormatter,
if (showZero || constant_) os << "\"" << this->id() << "\" [label=\"" const ValueFormatter& valueFormatter,
<< boost::format("%4.2g") % constant_ bool showZero) const override {
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, std::string value = valueFormatter(constant_);
if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
} }
/** evaluate */ /** evaluate */
@ -151,7 +160,7 @@ namespace gtsam {
/** incremental allSame */ /** incremental allSame */
size_t allSame_; size_t allSame_;
typedef boost::shared_ptr<const Choice> ChoicePtr; using ChoicePtr = boost::shared_ptr<const Choice>;
public: public:
@ -236,16 +245,19 @@ namespace gtsam {
} }
/** print (as a tree) */ /** print (as a tree) */
void print(const std::string& s) const override { void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice("; std::cout << s << " Choice(";
// std::cout << this << ","; std::cout << labelFormatter(label_) << ") " << std::endl;
std::cout << label_ << ") " << std::endl;
for (size_t i = 0; i < branches_.size(); i++) for (size_t i = 0; i < branches_.size(); i++)
branches_[i]->print((boost::format("%s %d") % s % i).str()); branches_[i]->print((boost::format("%s %d") % s % i).str(),
labelFormatter, valueFormatter);
} }
/** output to graphviz (as a a graph) */ /** output to graphviz (as a a graph) */
void dot(std::ostream& os, bool showZero) const override { void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
<< "\"]\n"; << "\"]\n";
size_t B = branches_.size(); size_t B = branches_.size();
@ -255,7 +267,8 @@ namespace gtsam {
// Check if zero // Check if zero
if (!showZero) { if (!showZero) {
const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get()); const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get());
if (leaf && !leaf->constant()) continue; std::string value = valueFormatter(leaf->constant());
if (leaf && value.compare("0")) continue;
} }
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
@ -264,7 +277,7 @@ namespace gtsam {
if (i > 1) os << " [style=bold]"; if (i > 1) os << " [style=bold]";
} }
os << std::endl; os << std::endl;
branch->dot(os, showZero); branch->dot(os, labelFormatter, valueFormatter, showZero);
} }
} }
@ -278,15 +291,16 @@ namespace gtsam {
return (q.isLeaf() && q.sameLeaf(*this)); return (q.isLeaf() && q.sameLeaf(*this));
} }
/** equality up to tolerance */ /** equality */
bool equals(const Node& q, double tol) const override { bool equals(const Node& q, const CompareFunc& compare) const override {
const Choice* other = dynamic_cast<const Choice*> (&q); const Choice* other = dynamic_cast<const Choice*>(&q);
if (!other) return false; if (!other) return false;
if (this->label_ != other->label_) return false; if (this->label_ != other->label_) return false;
if (branches_.size() != other->branches_.size()) return false; if (branches_.size() != other->branches_.size()) return false;
// we don't care about shared pointers being equal here // we don't care about shared pointers being equal here
for (size_t i = 0; i < branches_.size(); i++) for (size_t i = 0; i < branches_.size(); i++)
if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false; if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
return false;
return true; return true;
} }
@ -450,11 +464,25 @@ namespace gtsam {
} }
/*********************************************************************************/ /*********************************************************************************/
template<typename L, typename Y> template <typename L, typename Y>
template<typename M, typename X> template <typename X>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> Y_of_X) {
// Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
}
/*********************************************************************************/
template <typename L, typename Y>
template <typename M, typename X>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
const std::map<M, L>& map, std::function<Y(const X&)> op) { const std::map<M, L>& map,
root_ = convert(other.root_, map, op); std::function<Y(const X&)> Y_of_X) {
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
return map.at(label);
};
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
} }
/*********************************************************************************/ /*********************************************************************************/
@ -567,50 +595,53 @@ namespace gtsam {
} }
/*********************************************************************************/ /*********************************************************************************/
template<typename L, typename Y> template <typename L, typename Y>
template<typename M, typename X> template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
const typename DecisionTree<M, X>::NodePtr& f, const std::map<M, L>& map, const typename DecisionTree<M, X>::NodePtr& f,
std::function<Y(const X&)> op) { std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const {
typedef DecisionTree<M, X> MX; using MX = DecisionTree<M, X>;
typedef typename MX::Leaf MXLeaf; using MXLeaf = typename MX::Leaf;
typedef typename MX::Choice MXChoice; using MXChoice = typename MX::Choice;
typedef typename MX::NodePtr MXNodePtr; using MXNodePtr = typename MX::NodePtr;
typedef DecisionTree<L, Y> LY; using LY = DecisionTree<L, Y>;
// ugliness below because apparently we can't have templated virtual functions // ugliness below because apparently we can't have templated virtual functions
// If leaf, apply unary conversion "op" and create a unique leaf // If leaf, apply unary conversion "op" and create a unique leaf
const MXLeaf* leaf = dynamic_cast<const MXLeaf*> (f.get()); auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant())));
// Check if Choice // Check if Choice
boost::shared_ptr<const MXChoice> choice = boost::dynamic_pointer_cast<const MXChoice> (f); auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument( if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr"); "DecisionTree::Convert: Invalid NodePtr");
// get new label // get new label
M oldLabel = choice->label(); const M oldLabel = choice->label();
L newLabel = map.at(oldLabel); const L newLabel = L_of_M(oldLabel);
// put together via Shannon expansion otherwise not sorted. // put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions; std::vector<LY> functions;
for(const MXNodePtr& branch: choice->branches()) { for(const MXNodePtr& branch: choice->branches()) {
LY converted(convert<M, X>(branch, map, op)); LY converted(convertFrom<M, X>(branch, L_of_M, Y_of_X));
functions += converted; functions += converted;
} }
return LY::compose(functions.begin(), functions.end(), newLabel); return LY::compose(functions.begin(), functions.end(), newLabel);
} }
/*********************************************************************************/ /*********************************************************************************/
template<typename L, typename Y> template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other, double tol) const { bool DecisionTree<L, Y>::equals(const DecisionTree& other,
return root_->equals(*other.root_, tol); const CompareFunc& compare) const {
return root_->equals(*other.root_, compare);
} }
template<typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::print(const std::string& s) const { void DecisionTree<L, Y>::print(const std::string& s,
root_->print(s); const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const {
root_->print(s, labelFormatter, valueFormatter);
} }
template<typename L, typename Y> template<typename L, typename Y>
@ -625,6 +656,11 @@ namespace gtsam {
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const { DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
// It is unclear what should happen if tree is empty:
if (empty()) {
throw std::runtime_error(
"DecisionTree::apply(unary op) undefined for empty tree.");
}
return DecisionTree(root_->apply(op)); return DecisionTree(root_->apply(op));
} }
@ -632,6 +668,11 @@ namespace gtsam {
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g, DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
const Binary& op) const { const Binary& op) const {
// It is unclear what should happen if either tree is empty:
if (empty() || g.empty()) {
throw std::runtime_error(
"DecisionTree::apply(binary op) undefined for empty trees.");
}
// apply the operaton on the root of both diagrams // apply the operaton on the root of both diagrams
NodePtr h = root_->apply_f_op_g(*g.root_, op); NodePtr h = root_->apply_f_op_g(*g.root_, op);
// create a new class with the resulting root "h" // create a new class with the resulting root "h"
@ -660,26 +701,34 @@ namespace gtsam {
} }
/*********************************************************************************/ /*********************************************************************************/
template<typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os, bool showZero) const { void DecisionTree<L, Y>::dot(std::ostream& os,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
os << "digraph G {\n"; os << "digraph G {\n";
root_->dot(os, showZero); root_->dot(os, labelFormatter, valueFormatter, showZero);
os << " [ordering=out]}" << std::endl; os << " [ordering=out]}" << std::endl;
} }
template<typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::dot(const std::string& name, bool showZero) const { void DecisionTree<L, Y>::dot(const std::string& name,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
std::ofstream os((name + ".dot").c_str()); std::ofstream os((name + ".dot").c_str());
dot(os, showZero); dot(os, labelFormatter, valueFormatter, showZero);
int result = system( int result = system(
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
} }
template<typename L, typename Y> template <typename L, typename Y>
std::string DecisionTree<L, Y>::dot(bool showZero) const { std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
std::stringstream ss; std::stringstream ss;
dot(ss, showZero); dot(ss, labelFormatter, valueFormatter, showZero);
return ss.str(); return ss.str();
} }

View File

@ -20,13 +20,13 @@
#pragma once #pragma once
#include <gtsam/base/types.h> #include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <boost/function.hpp> #include <boost/function.hpp>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <sstream>
#include <vector> #include <vector>
namespace gtsam { namespace gtsam {
@ -39,14 +39,24 @@ namespace gtsam {
template<typename L, typename Y> template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree { class GTSAM_EXPORT DecisionTree {
protected:
/// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) {
return a == b;
}
public: public:
using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>;
using CompareFunc = std::function<bool(const Y&, const Y&)>;
/** Handy typedefs for unary and binary function types */ /** Handy typedefs for unary and binary function types */
typedef std::function<Y(const Y&)> Unary; using Unary = std::function<Y(const Y&)>;
typedef std::function<Y(const Y&, const Y&)> Binary; using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */ /** A label annotated with cardinality */
typedef std::pair<L,size_t> LabelC; using LabelC = std::pair<L,size_t>;
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */ /** DTs consist of Leaf and Choice nodes, both subclasses of Node */
class Leaf; class Leaf;
@ -55,7 +65,7 @@ namespace gtsam {
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
class Node { class Node {
public: public:
typedef boost::shared_ptr<const Node> Ptr; using Ptr = boost::shared_ptr<const Node>;
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
static int nrNodes; static int nrNodes;
@ -79,11 +89,16 @@ namespace gtsam {
const void* id() const { return this; } const void* id() const { return this; }
// everything else is virtual, no documentation here as internal // everything else is virtual, no documentation here as internal
virtual void print(const std::string& s = "") const = 0; virtual void print(const std::string& s,
virtual void dot(std::ostream& os, bool showZero) const = 0; const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const = 0;
virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const = 0;
virtual bool sameLeaf(const Leaf& q) const = 0; virtual bool sameLeaf(const Leaf& q) const = 0;
virtual bool sameLeaf(const Node& q) const = 0; virtual bool sameLeaf(const Node& q) const = 0;
virtual bool equals(const Node& other, double tol = 1e-9) const = 0; virtual bool equals(const Node& other, const CompareFunc& compare =
&DefaultCompare) const = 0;
virtual const Y& operator()(const Assignment<L>& x) const = 0; virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
@ -97,9 +112,9 @@ namespace gtsam {
public: public:
/** A function is a shared pointer to the root of a DT */ /** A function is a shared pointer to the root of a DT */
typedef typename Node::Ptr NodePtr; using NodePtr = typename Node::Ptr;
/* a DecisionTree just contains the root */ /// A DecisionTree just contains the root. TODO(dellaert): make protected.
NodePtr root_; NodePtr root_;
protected: protected:
@ -108,19 +123,29 @@ namespace gtsam {
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
/** Convert to a different type */ /**
template<typename M, typename X> NodePtr * @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
convert(const typename DecisionTree<M, X>::NodePtr& f, const std::map<M, *
L>& map, std::function<Y(const X&)> op); * @tparam M The previous label type.
* @tparam X The previous value type.
* @param f The node pointer to the root of the previous DecisionTree.
* @param L_of_M Functor to convert from label type M to type L.
* @param Y_of_X Functor to convert from value type X to type Y.
* @return NodePtr
*/
template <typename M, typename X>
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const;
/** Default constructor */ public:
DecisionTree();
public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** Default constructor (for serialization) */
DecisionTree();
/** Create a constant */ /** Create a constant */
DecisionTree(const Y& y); DecisionTree(const Y& y);
@ -144,20 +169,48 @@ namespace gtsam {
DecisionTree(const L& label, // DecisionTree(const L& label, //
const DecisionTree& f0, const DecisionTree& f1); const DecisionTree& f0, const DecisionTree& f1);
/** Convert from a different type */ /**
template<typename M, typename X> * @brief Convert from a different value type.
DecisionTree(const DecisionTree<M, X>& other, *
const std::map<M, L>& map, std::function<Y(const X&)> op); * @tparam X The previous value type.
* @param other The DecisionTree to convert from.
* @param Y_of_X Functor to convert from value type X to type Y.
*/
template <typename X>
DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> Y_of_X);
/**
* @brief Convert from a different value type X to value type Y, also transate
* labels via map from type M to L.
*
* @tparam M Previous label type.
* @tparam X Previous value type.
* @param other The decision tree to convert.
* @param L_of_M Map from label type M to type L.
* @param Y_of_X Functor to convert from type X to type Y.
*/
template <typename M, typename X>
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
std::function<Y(const X&)> Y_of_X);
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
/** GTSAM-style print */ /**
void print(const std::string& s = "DecisionTree") const; * @brief GTSAM-style print
*
* @param s Prefix string.
* @param labelFormatter Functor to format the node label.
* @param valueFormatter Functor to format the node value.
*/
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const;
// Testable // Testable
bool equals(const DecisionTree& other, double tol = 1e-9) const; bool equals(const DecisionTree& other,
const CompareFunc& compare = &DefaultCompare) const;
/// @} /// @}
/// @name Standard Interface /// @name Standard Interface
@ -167,6 +220,9 @@ namespace gtsam {
virtual ~DecisionTree() { virtual ~DecisionTree() {
} }
/// Check if tree is empty.
bool empty() const { return !root_; }
/** equality */ /** equality */
bool operator==(const DecisionTree& q) const; bool operator==(const DecisionTree& q) const;
@ -195,13 +251,17 @@ namespace gtsam {
} }
/** output to graphviz format, stream version */ /** output to graphviz format, stream version */
void dot(std::ostream& os, bool showZero = true) const; void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, bool showZero = true) const;
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void dot(const std::string& name, bool showZero = true) const; void dot(const std::string& name, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, bool showZero = true) const;
/** output to graphviz format string */ /** output to graphviz format string */
std::string dot(bool showZero = true) const; std::string dot(const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero = true) const;
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
@ -219,13 +279,15 @@ namespace gtsam {
/** free versions of apply */ /** free versions of apply */
template<typename Y, typename L> /// Apply unary operator `op` to DecisionTree `f`.
template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f, DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const typename DecisionTree<L, Y>::Unary& op) { const typename DecisionTree<L, Y>::Unary& op) {
return f.apply(op); return f.apply(op);
} }
template<typename Y, typename L> /// Apply binary operator `op` to DecisionTree `f`.
template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f, DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const DecisionTree<L, Y>& g, const DecisionTree<L, Y>& g,
const typename DecisionTree<L, Y>::Binary& op) { const typename DecisionTree<L, Y>::Binary& op) {

View File

@ -153,6 +153,31 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************* */
static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str();
}
/** output to graphviz format, stream version */
void DecisionTreeFactor::dot(std::ostream& os,
const KeyFormatter& keyFormatter,
bool showZero) const {
Potentials::dot(os, keyFormatter, valueFormatter, showZero);
}
/** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter,
bool showZero) const {
Potentials::dot(name, keyFormatter, valueFormatter, showZero);
}
/** output to graphviz format string */
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
bool showZero) const {
return Potentials::dot(keyFormatter, valueFormatter, showZero);
}
/* ************************************************************************* */ /* ************************************************************************* */
std::string DecisionTreeFactor::markdown( std::string DecisionTreeFactor::markdown(
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter) const {

View File

@ -178,6 +178,20 @@ namespace gtsam {
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
/** output to graphviz format, stream version */
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/** output to graphviz format, open a file */
void dot(const std::string& name,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/** output to graphviz format string */
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/// Render as markdown table. /// Render as markdown table.
std::string markdown( std::string markdown(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

View File

@ -51,11 +51,11 @@ bool Potentials::equals(const Potentials& other, double tol) const {
/* ************************************************************************* */ /* ************************************************************************* */
void Potentials::print(const string& s, const KeyFormatter& formatter) const { void Potentials::print(const string& s, const KeyFormatter& formatter) const {
cout << s << "\n Cardinalities: {"; cout << s << "\n Cardinalities: { ";
for (const std::pair<const Key,size_t>& key : cardinalities_) for (const std::pair<const Key,size_t>& key : cardinalities_)
cout << formatter(key.first) << ":" << key.second << ", "; cout << formatter(key.first) << ":" << key.second << ", ";
cout << "}" << endl; cout << "}" << endl;
ADT::print(" "); ADT::print(" ", formatter);
} }
// //
// /* ************************************************************************* */ // /* ************************************************************************* */

View File

@ -46,7 +46,9 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
string dot(bool showZero = false) const; string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
bool showZero = false) const;
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;

View File

@ -136,8 +136,8 @@ ADT create(const Signature& signature) {
ADT p(signature.discreteKeys(), signature.cpt()); ADT p(signature.discreteKeys(), signature.cpt());
static size_t count = 0; static size_t count = 0;
const DiscreteKey& key = signature.key(); const DiscreteKey& key = signature.key();
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
dot(p, dotfile); dot(p, DOTfile);
return p; return p;
} }
@ -414,13 +414,13 @@ TEST(ADT, equality_noparser)
// Check straight equality // Check straight equality
ADT pA1 = create(A % tableA); ADT pA1 = create(A % tableA);
ADT pA2 = create(A % tableA); ADT pA2 = create(A % tableA);
EXPECT(pA1 == pA2); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % tableB); ADT pB = create(B % tableB);
ADT pAB1 = apply(pA1, pB, &mul); ADT pAB1 = apply(pA1, pB, &mul);
ADT pAB2 = apply(pB, pA1, &mul); ADT pAB2 = apply(pB, pA1, &mul);
EXPECT(pAB2 == pAB1); EXPECT(pAB2.equals(pAB1));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -431,13 +431,13 @@ TEST(ADT, equality_parser)
// Check straight equality // Check straight equality
ADT pA1 = create(A % "80/20"); ADT pA1 = create(A % "80/20");
ADT pA2 = create(A % "80/20"); ADT pA2 = create(A % "80/20");
EXPECT(pA1 == pA2); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % "60/40"); ADT pB = create(B % "60/40");
ADT pAB1 = apply(pA1, pB, &mul); ADT pAB1 = apply(pA1, pB, &mul);
ADT pAB2 = apply(pB, pA1, &mul); ADT pAB2 = apply(pB, pA1, &mul);
EXPECT(pAB2 == pAB1); EXPECT(pAB2.equals(pAB1));
} }
/* ******************************************************************************** */ /* ******************************************************************************** */

View File

@ -40,25 +40,69 @@ void dot(const T&f, const string& filename) {
#define DOT(x)(dot(x,#x)) #define DOT(x)(dot(x,#x))
struct Crazy { int a; double b; }; struct Crazy {
typedef DecisionTree<string,Crazy> CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be) int a;
double b;
};
struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
/// print to stdout
void print(const std::string& s = "") const {
auto keyFormatter = [](const std::string& s) { return s; };
auto valueFormatter = [](const Crazy& v) {
return (boost::format("{%d,%4.2g}") % v.a % v.b).str();
};
DecisionTree<string, Crazy>::print("", keyFormatter, valueFormatter);
}
/// Equality method customized to Crazy node type
bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
auto compare = [tol](const Crazy& v, const Crazy& w) {
return v.a == w.a && std::abs(v.b - w.b) < tol;
};
return DecisionTree<string, Crazy>::equals(other, compare);
}
};
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {}; template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
} }
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ******************************************************************************** */ /* ******************************************************************************** */
// Test string labels and int range // Test string labels and int range
/* ******************************************************************************** */ /* ******************************************************************************** */
typedef DecisionTree<string, int> DT; struct DT : public DecisionTree<string, int> {
using Base = DecisionTree<string, int>;
using DecisionTree::DecisionTree;
DT() = default;
DT(const Base& dt) : Base(dt) {}
/// print to stdout
void print(const std::string& s = "") const {
auto keyFormatter = [](const std::string& s) { return s; };
auto valueFormatter = [](const int& v) {
return (boost::format("%d") % v).str();
};
Base::print("", keyFormatter, valueFormatter);
}
/// Equality method customized to int node type
bool equals(const Base& other, double tol = 1e-9) const {
auto compare = [](const int& v, const int& w) { return v == w; };
return Base::equals(other, compare);
}
};
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<DT> : public Testable<DT> {}; template<> struct traits<DT> : public Testable<DT> {};
} }
GTSAM_CONCEPT_TESTABLE_INST(DT)
struct Ring { struct Ring {
static inline int zero() { static inline int zero() {
return 0; return 0;
@ -66,6 +110,9 @@ struct Ring {
static inline int one() { static inline int one() {
return 1; return 1;
} }
static inline int id(const int& a) {
return a;
}
static inline int add(const int& a, const int& b) { static inline int add(const int& a, const int& b) {
return a + b; return a + b;
} }
@ -88,6 +135,9 @@ TEST(DT, example)
x10[A] = 1, x10[B] = 0; x10[A] = 1, x10[B] = 0;
x11[A] = 1, x11[B] = 1; x11[A] = 1, x11[B] = 1;
// empty
DT empty;
// A // A
DT a(A, 0, 5); DT a(A, 0, 5);
LONGS_EQUAL(0,a(x00)) LONGS_EQUAL(0,a(x00))
@ -106,6 +156,11 @@ TEST(DT, example)
LONGS_EQUAL(5,notb(x10)) LONGS_EQUAL(5,notb(x10))
DOT(notb); DOT(notb);
// Check supplying empty trees yields an exception
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
// apply, two nodes, in natural order // apply, two nodes, in natural order
DT anotb = apply(a, notb, &Ring::mul); DT anotb = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,anotb(x00)) LONGS_EQUAL(0,anotb(x00))
@ -175,16 +230,37 @@ TEST(DT, example)
} }
/* ******************************************************************************** */ /* ******************************************************************************** */
// test Conversion // test Conversion of values
std::function<bool(const int&)> bool_of_int = [](const int& y) {
return y != 0;
};
typedef DecisionTree<string, bool> StringBoolTree;
TEST(DT, ConvertValuesOnly)
{
// Create labels
string A("A"), B("B");
// apply, two nodes, in natural order
DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul);
// convert
StringBoolTree f2(f1, bool_of_int);
// Check a value
Assignment<string> x00;
x00["A"] = 0, x00["B"] = 0;
EXPECT(!f2(x00));
}
/* ******************************************************************************** */
// test Conversion of both values and labels.
enum Label { enum Label {
U, V, X, Y, Z U, V, X, Y, Z
}; };
typedef DecisionTree<Label, bool> BDT; typedef DecisionTree<Label, bool> LabelBoolTree;
bool convert(const int& y) {
return y != 0;
}
TEST(DT, conversion) TEST(DT, ConvertBoth)
{ {
// Create labels // Create labels
string A("A"), B("B"); string A("A"), B("B");
@ -196,12 +272,9 @@ TEST(DT, conversion)
map<string, Label> ordering; map<string, Label> ordering;
ordering[A] = X; ordering[A] = X;
ordering[B] = Y; ordering[B] = Y;
std::function<bool(const int&)> op = convert; LabelBoolTree f2(f1, ordering, bool_of_int);
BDT f2(f1, ordering, op);
// f1.print("f1");
// f2.print("f2");
// create a value // Check some values
Assignment<Label> x00, x01, x10, x11; Assignment<Label> x00, x01, x10, x11;
x00[X] = 0, x00[Y] = 0; x00[X] = 0, x00[Y] = 0;
x01[X] = 0, x01[Y] = 1; x01[X] = 0, x01[Y] = 1;