Merge pull request #1000 from borglab/feature/decison_tree

release/4.3a0
Frank Dellaert 2022-01-03 17:57:59 -05:00 committed by GitHub
commit 15742270de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 401 additions and 138 deletions

View File

@ -28,11 +28,22 @@ namespace gtsam {
* TODO: consider eliminating this class altogether?
*/
template<typename L>
class AlgebraicDecisionTree: public DecisionTree<L, double> {
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
/**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
*
* @param x The value passed to format.
* @return std::string
*/
static std::string DefaultFormatter(const L& x) {
std::stringstream ss;
ss << x;
return ss.str();
}
public:
typedef DecisionTree<L, double> Super;
using Base = DecisionTree<L, double>;
/** The Real ring with addition and multiplication */
struct Ring {
@ -60,33 +71,33 @@ namespace gtsam {
};
AlgebraicDecisionTree() :
Super(1.0) {
Base(1.0) {
}
AlgebraicDecisionTree(const Super& add) :
Super(add) {
AlgebraicDecisionTree(const Base& add) :
Base(add) {
}
/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) :
Super(label, y1, y2) {
Base(label, y1, y2) {
}
/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) :
Super(labelC, y1, y2) {
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
Base(labelC, y1, y2) {
}
/** Create from keys and vector table */
AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
}
/** Create from keys and string table */
AlgebraicDecisionTree //
(const std::vector<typename Super::LabelC>& labelCs, const std::string& table) {
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
// Convert string to doubles
std::vector<double> ys;
std::istringstream iss(table);
@ -94,23 +105,32 @@ namespace gtsam {
std::istream_iterator<double>(), std::back_inserter(ys));
// now call recursive Create
this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(),
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
}
/** Create a new function splitting on a variable */
template<typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
Super(nullptr) {
Base(nullptr) {
this->root_ = compose(begin, end, label);
}
/** Convert */
/**
* Convert labels from type M to type L.
*
* @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L.
*/
template<typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) {
this->root_ = this->template convert<M, double>(other.root_, map,
Ring::id);
// Functor for label conversion so we can use `convertFrom`.
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
return map.at(label);
};
std::function<double(const double&)> op = Ring::id;
this->root_ = this->template convertFrom(other.root_, L_of_M, op);
}
/** sum */
@ -134,10 +154,28 @@ namespace gtsam {
}
/** sum out variable */
AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const {
AlgebraicDecisionTree sum(const typename Base::LabelC& labelC) const {
return this->combine(labelC, &Ring::add);
}
/// print method customized to value type `double`.
void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str();
};
Base::print(s, labelFormatter, valueFormatter);
}
/// Equality method customized to value type `double`.
bool equals(const AlgebraicDecisionTree& other, double tol = 1e-9) const {
// lambda for comparison of two doubles upto some tolerance.
auto compare = [tol](double a, double b) {
return std::abs(a - b) < tol;
};
return Base::equals(other, compare);
}
};
// AlgebraicDecisionTree

View File

@ -20,21 +20,21 @@
#pragma once
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/base/Testable.h>
#include <boost/assign/std/vector.hpp>
#include <boost/format.hpp>
#include <boost/noncopyable.hpp>
#include <boost/optional.hpp>
#include <boost/tuple/tuple.hpp>
#include <boost/assign/std/vector.hpp>
using boost::assign::operator+=;
#include <boost/type_traits/has_dereference.hpp>
#include <boost/unordered_set.hpp>
#include <boost/noncopyable.hpp>
#include <list>
#include <cmath>
#include <fstream>
#include <list>
#include <sstream>
using boost::assign::operator+=;
namespace gtsam {
/*********************************************************************************/
@ -76,22 +76,31 @@ namespace gtsam {
}
/** equality up to tolerance */
bool equals(const Node& q, double tol) const override {
const Leaf* other = dynamic_cast<const Leaf*> (&q);
bool equals(const Node& q, const CompareFunc& compare) const override {
const Leaf* other = dynamic_cast<const Leaf*>(&q);
if (!other) return false;
return std::abs(double(this->constant_ - other->constant_)) < tol;
return compare(this->constant_, other->constant_);
}
/** print */
void print(const std::string& s) const override {
bool showZero = true;
if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl;
/**
* @brief Print method.
*
* @param s Prefix string.
* @param labelFormatter Functor to format the labels of type L.
* @param valueFormatter Functor to format the values of type Y.
*/
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
}
/** to graphviz file */
void dot(std::ostream& os, bool showZero) const override {
if (showZero || constant_) os << "\"" << this->id() << "\" [label=\""
<< boost::format("%4.2g") % constant_
/** Write graphviz format to stream `os`. */
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
std::string value = valueFormatter(constant_);
if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
}
@ -151,7 +160,7 @@ namespace gtsam {
/** incremental allSame */
size_t allSame_;
typedef boost::shared_ptr<const Choice> ChoicePtr;
using ChoicePtr = boost::shared_ptr<const Choice>;
public:
@ -236,16 +245,19 @@ namespace gtsam {
}
/** print (as a tree) */
void print(const std::string& s) const override {
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice(";
// std::cout << this << ",";
std::cout << label_ << ") " << std::endl;
std::cout << labelFormatter(label_) << ") " << std::endl;
for (size_t i = 0; i < branches_.size(); i++)
branches_[i]->print((boost::format("%s %d") % s % i).str());
branches_[i]->print((boost::format("%s %d") % s % i).str(),
labelFormatter, valueFormatter);
}
/** output to graphviz (as a a graph) */
void dot(std::ostream& os, bool showZero) const override {
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
<< "\"]\n";
size_t B = branches_.size();
@ -255,7 +267,8 @@ namespace gtsam {
// Check if zero
if (!showZero) {
const Leaf* leaf = dynamic_cast<const Leaf*> (branch.get());
if (leaf && !leaf->constant()) continue;
std::string value = valueFormatter(leaf->constant());
if (leaf && value.compare("0")) continue;
}
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
@ -264,7 +277,7 @@ namespace gtsam {
if (i > 1) os << " [style=bold]";
}
os << std::endl;
branch->dot(os, showZero);
branch->dot(os, labelFormatter, valueFormatter, showZero);
}
}
@ -278,15 +291,16 @@ namespace gtsam {
return (q.isLeaf() && q.sameLeaf(*this));
}
/** equality up to tolerance */
bool equals(const Node& q, double tol) const override {
const Choice* other = dynamic_cast<const Choice*> (&q);
/** equality */
bool equals(const Node& q, const CompareFunc& compare) const override {
const Choice* other = dynamic_cast<const Choice*>(&q);
if (!other) return false;
if (this->label_ != other->label_) return false;
if (branches_.size() != other->branches_.size()) return false;
// we don't care about shared pointers being equal here
for (size_t i = 0; i < branches_.size(); i++)
if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false;
if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
return false;
return true;
}
@ -450,11 +464,25 @@ namespace gtsam {
}
/*********************************************************************************/
template<typename L, typename Y>
template<typename M, typename X>
template <typename L, typename Y>
template <typename X>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> Y_of_X) {
// Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
}
/*********************************************************************************/
template <typename L, typename Y>
template <typename M, typename X>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
const std::map<M, L>& map, std::function<Y(const X&)> op) {
root_ = convert(other.root_, map, op);
const std::map<M, L>& map,
std::function<Y(const X&)> Y_of_X) {
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
return map.at(label);
};
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
}
/*********************************************************************************/
@ -567,50 +595,53 @@ namespace gtsam {
}
/*********************************************************************************/
template<typename L, typename Y>
template<typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert(
const typename DecisionTree<M, X>::NodePtr& f, const std::map<M, L>& map,
std::function<Y(const X&)> op) {
typedef DecisionTree<M, X> MX;
typedef typename MX::Leaf MXLeaf;
typedef typename MX::Choice MXChoice;
typedef typename MX::NodePtr MXNodePtr;
typedef DecisionTree<L, Y> LY;
template <typename L, typename Y>
template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const {
using MX = DecisionTree<M, X>;
using MXLeaf = typename MX::Leaf;
using MXChoice = typename MX::Choice;
using MXNodePtr = typename MX::NodePtr;
using LY = DecisionTree<L, Y>;
// ugliness below because apparently we can't have templated virtual functions
// If leaf, apply unary conversion "op" and create a unique leaf
const MXLeaf* leaf = dynamic_cast<const MXLeaf*> (f.get());
if (leaf) return NodePtr(new Leaf(op(leaf->constant())));
auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant())));
// Check if Choice
boost::shared_ptr<const MXChoice> choice = boost::dynamic_pointer_cast<const MXChoice> (f);
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr");
// get new label
M oldLabel = choice->label();
L newLabel = map.at(oldLabel);
const M oldLabel = choice->label();
const L newLabel = L_of_M(oldLabel);
// put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions;
for(const MXNodePtr& branch: choice->branches()) {
LY converted(convert<M, X>(branch, map, op));
LY converted(convertFrom<M, X>(branch, L_of_M, Y_of_X));
functions += converted;
}
return LY::compose(functions.begin(), functions.end(), newLabel);
}
/*********************************************************************************/
template<typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other, double tol) const {
return root_->equals(*other.root_, tol);
template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const {
return root_->equals(*other.root_, compare);
}
template<typename L, typename Y>
void DecisionTree<L, Y>::print(const std::string& s) const {
root_->print(s);
template <typename L, typename Y>
void DecisionTree<L, Y>::print(const std::string& s,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const {
root_->print(s, labelFormatter, valueFormatter);
}
template<typename L, typename Y>
@ -625,6 +656,11 @@ namespace gtsam {
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
// It is unclear what should happen if tree is empty:
if (empty()) {
throw std::runtime_error(
"DecisionTree::apply(unary op) undefined for empty tree.");
}
return DecisionTree(root_->apply(op));
}
@ -632,6 +668,11 @@ namespace gtsam {
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
const Binary& op) const {
// It is unclear what should happen if either tree is empty:
if (empty() || g.empty()) {
throw std::runtime_error(
"DecisionTree::apply(binary op) undefined for empty trees.");
}
// apply the operaton on the root of both diagrams
NodePtr h = root_->apply_f_op_g(*g.root_, op);
// create a new class with the resulting root "h"
@ -660,26 +701,34 @@ namespace gtsam {
}
/*********************************************************************************/
template<typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os, bool showZero) const {
template <typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
os << "digraph G {\n";
root_->dot(os, showZero);
root_->dot(os, labelFormatter, valueFormatter, showZero);
os << " [ordering=out]}" << std::endl;
}
template<typename L, typename Y>
void DecisionTree<L, Y>::dot(const std::string& name, bool showZero) const {
template <typename L, typename Y>
void DecisionTree<L, Y>::dot(const std::string& name,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
std::ofstream os((name + ".dot").c_str());
dot(os, showZero);
dot(os, labelFormatter, valueFormatter, showZero);
int result = system(
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
}
template<typename L, typename Y>
std::string DecisionTree<L, Y>::dot(bool showZero) const {
template <typename L, typename Y>
std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const {
std::stringstream ss;
dot(ss, showZero);
dot(ss, labelFormatter, valueFormatter, showZero);
return ss.str();
}

View File

@ -20,13 +20,13 @@
#pragma once
#include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h>
#include <boost/function.hpp>
#include <functional>
#include <iostream>
#include <map>
#include <sstream>
#include <vector>
namespace gtsam {
@ -39,14 +39,24 @@ namespace gtsam {
template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree {
protected:
/// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) {
return a == b;
}
public:
using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>;
using CompareFunc = std::function<bool(const Y&, const Y&)>;
/** Handy typedefs for unary and binary function types */
typedef std::function<Y(const Y&)> Unary;
typedef std::function<Y(const Y&, const Y&)> Binary;
using Unary = std::function<Y(const Y&)>;
using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */
typedef std::pair<L,size_t> LabelC;
using LabelC = std::pair<L,size_t>;
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
class Leaf;
@ -55,7 +65,7 @@ namespace gtsam {
/** ------------------------ Node base class --------------------------- */
class Node {
public:
typedef boost::shared_ptr<const Node> Ptr;
using Ptr = boost::shared_ptr<const Node>;
#ifdef DT_DEBUG_MEMORY
static int nrNodes;
@ -79,11 +89,16 @@ namespace gtsam {
const void* id() const { return this; }
// everything else is virtual, no documentation here as internal
virtual void print(const std::string& s = "") const = 0;
virtual void dot(std::ostream& os, bool showZero) const = 0;
virtual void print(const std::string& s,
const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const = 0;
virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const = 0;
virtual bool sameLeaf(const Leaf& q) const = 0;
virtual bool sameLeaf(const Node& q) const = 0;
virtual bool equals(const Node& other, double tol = 1e-9) const = 0;
virtual bool equals(const Node& other, const CompareFunc& compare =
&DefaultCompare) const = 0;
virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
@ -97,9 +112,9 @@ namespace gtsam {
public:
/** A function is a shared pointer to the root of a DT */
typedef typename Node::Ptr NodePtr;
using NodePtr = typename Node::Ptr;
/* a DecisionTree just contains the root */
/// A DecisionTree just contains the root. TODO(dellaert): make protected.
NodePtr root_;
protected:
@ -108,19 +123,29 @@ namespace gtsam {
template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
/** Convert to a different type */
template<typename M, typename X> NodePtr
convert(const typename DecisionTree<M, X>::NodePtr& f, const std::map<M,
L>& map, std::function<Y(const X&)> op);
/** Default constructor */
DecisionTree();
/**
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
*
* @tparam M The previous label type.
* @tparam X The previous value type.
* @param f The node pointer to the root of the previous DecisionTree.
* @param L_of_M Functor to convert from label type M to type L.
* @param Y_of_X Functor to convert from value type X to type Y.
* @return NodePtr
*/
template <typename M, typename X>
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const;
public:
/// @name Standard Constructors
/// @{
/** Default constructor (for serialization) */
DecisionTree();
/** Create a constant */
DecisionTree(const Y& y);
@ -144,20 +169,48 @@ namespace gtsam {
DecisionTree(const L& label, //
const DecisionTree& f0, const DecisionTree& f1);
/** Convert from a different type */
template<typename M, typename X>
DecisionTree(const DecisionTree<M, X>& other,
const std::map<M, L>& map, std::function<Y(const X&)> op);
/**
* @brief Convert from a different value type.
*
* @tparam X The previous value type.
* @param other The DecisionTree to convert from.
* @param Y_of_X Functor to convert from value type X to type Y.
*/
template <typename X>
DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> Y_of_X);
/**
* @brief Convert from a different value type X to value type Y, also transate
* labels via map from type M to L.
*
* @tparam M Previous label type.
* @tparam X Previous value type.
* @param other The decision tree to convert.
* @param L_of_M Map from label type M to type L.
* @param Y_of_X Functor to convert from type X to type Y.
*/
template <typename M, typename X>
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
std::function<Y(const X&)> Y_of_X);
/// @}
/// @name Testable
/// @{
/** GTSAM-style print */
void print(const std::string& s = "DecisionTree") const;
/**
* @brief GTSAM-style print
*
* @param s Prefix string.
* @param labelFormatter Functor to format the node label.
* @param valueFormatter Functor to format the node value.
*/
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const;
// Testable
bool equals(const DecisionTree& other, double tol = 1e-9) const;
bool equals(const DecisionTree& other,
const CompareFunc& compare = &DefaultCompare) const;
/// @}
/// @name Standard Interface
@ -167,6 +220,9 @@ namespace gtsam {
virtual ~DecisionTree() {
}
/// Check if tree is empty.
bool empty() const { return !root_; }
/** equality */
bool operator==(const DecisionTree& q) const;
@ -195,13 +251,17 @@ namespace gtsam {
}
/** output to graphviz format, stream version */
void dot(std::ostream& os, bool showZero = true) const;
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, bool showZero = true) const;
/** output to graphviz format, open a file */
void dot(const std::string& name, bool showZero = true) const;
void dot(const std::string& name, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, bool showZero = true) const;
/** output to graphviz format string */
std::string dot(bool showZero = true) const;
std::string dot(const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero = true) const;
/// @name Advanced Interface
/// @{
@ -219,13 +279,15 @@ namespace gtsam {
/** free versions of apply */
template<typename Y, typename L>
/// Apply unary operator `op` to DecisionTree `f`.
template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const typename DecisionTree<L, Y>::Unary& op) {
return f.apply(op);
}
template<typename Y, typename L>
/// Apply binary operator `op` to DecisionTree `f`.
template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const DecisionTree<L, Y>& g,
const typename DecisionTree<L, Y>::Binary& op) {

View File

@ -153,6 +153,31 @@ namespace gtsam {
return result;
}
/* ************************************************************************* */
static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str();
}
/** output to graphviz format, stream version */
void DecisionTreeFactor::dot(std::ostream& os,
const KeyFormatter& keyFormatter,
bool showZero) const {
Potentials::dot(os, keyFormatter, valueFormatter, showZero);
}
/** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter,
bool showZero) const {
Potentials::dot(name, keyFormatter, valueFormatter, showZero);
}
/** output to graphviz format string */
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
bool showZero) const {
return Potentials::dot(keyFormatter, valueFormatter, showZero);
}
/* ************************************************************************* */
std::string DecisionTreeFactor::markdown(
const KeyFormatter& keyFormatter) const {

View File

@ -178,6 +178,20 @@ namespace gtsam {
/// @name Wrapper support
/// @{
/** output to graphviz format, stream version */
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/** output to graphviz format, open a file */
void dot(const std::string& name,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/** output to graphviz format string */
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const;
/// Render as markdown table.
std::string markdown(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

View File

@ -51,11 +51,11 @@ bool Potentials::equals(const Potentials& other, double tol) const {
/* ************************************************************************* */
void Potentials::print(const string& s, const KeyFormatter& formatter) const {
cout << s << "\n Cardinalities: {";
cout << s << "\n Cardinalities: { ";
for (const std::pair<const Key,size_t>& key : cardinalities_)
cout << formatter(key.first) << ":" << key.second << ", ";
cout << "}" << endl;
ADT::print(" ");
ADT::print(" ", formatter);
}
//
// /* ************************************************************************* */

View File

@ -46,7 +46,9 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
string dot(bool showZero = false) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
bool showZero = false) const;
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;

View File

@ -136,8 +136,8 @@ ADT create(const Signature& signature) {
ADT p(signature.discreteKeys(), signature.cpt());
static size_t count = 0;
const DiscreteKey& key = signature.key();
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
dot(p, dotfile);
string DOTfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
dot(p, DOTfile);
return p;
}
@ -414,13 +414,13 @@ TEST(ADT, equality_noparser)
// Check straight equality
ADT pA1 = create(A % tableA);
ADT pA2 = create(A % tableA);
EXPECT(pA1 == pA2); // should be equal
EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply
ADT pB = create(B % tableB);
ADT pAB1 = apply(pA1, pB, &mul);
ADT pAB2 = apply(pB, pA1, &mul);
EXPECT(pAB2 == pAB1);
EXPECT(pAB2.equals(pAB1));
}
/* ************************************************************************* */
@ -431,13 +431,13 @@ TEST(ADT, equality_parser)
// Check straight equality
ADT pA1 = create(A % "80/20");
ADT pA2 = create(A % "80/20");
EXPECT(pA1 == pA2); // should be equal
EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply
ADT pB = create(B % "60/40");
ADT pAB1 = apply(pA1, pB, &mul);
ADT pAB2 = apply(pB, pA1, &mul);
EXPECT(pAB2 == pAB1);
EXPECT(pAB2.equals(pAB1));
}
/* ******************************************************************************** */

View File

@ -40,25 +40,69 @@ void dot(const T&f, const string& filename) {
#define DOT(x)(dot(x,#x))
struct Crazy { int a; double b; };
typedef DecisionTree<string,Crazy> CrazyDecisionTree; // check that DecisionTree is actually generic (as it pretends to be)
struct Crazy {
int a;
double b;
};
struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
/// print to stdout
void print(const std::string& s = "") const {
auto keyFormatter = [](const std::string& s) { return s; };
auto valueFormatter = [](const Crazy& v) {
return (boost::format("{%d,%4.2g}") % v.a % v.b).str();
};
DecisionTree<string, Crazy>::print("", keyFormatter, valueFormatter);
}
/// Equality method customized to Crazy node type
bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
auto compare = [tol](const Crazy& v, const Crazy& w) {
return v.a == w.a && std::abs(v.b - w.b) < tol;
};
return DecisionTree<string, Crazy>::equals(other, compare);
}
};
// traits
namespace gtsam {
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
}
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ******************************************************************************** */
// Test string labels and int range
/* ******************************************************************************** */
typedef DecisionTree<string, int> DT;
struct DT : public DecisionTree<string, int> {
using Base = DecisionTree<string, int>;
using DecisionTree::DecisionTree;
DT() = default;
DT(const Base& dt) : Base(dt) {}
/// print to stdout
void print(const std::string& s = "") const {
auto keyFormatter = [](const std::string& s) { return s; };
auto valueFormatter = [](const int& v) {
return (boost::format("%d") % v).str();
};
Base::print("", keyFormatter, valueFormatter);
}
/// Equality method customized to int node type
bool equals(const Base& other, double tol = 1e-9) const {
auto compare = [](const int& v, const int& w) { return v == w; };
return Base::equals(other, compare);
}
};
// traits
namespace gtsam {
template<> struct traits<DT> : public Testable<DT> {};
}
GTSAM_CONCEPT_TESTABLE_INST(DT)
struct Ring {
static inline int zero() {
return 0;
@ -66,6 +110,9 @@ struct Ring {
static inline int one() {
return 1;
}
static inline int id(const int& a) {
return a;
}
static inline int add(const int& a, const int& b) {
return a + b;
}
@ -88,6 +135,9 @@ TEST(DT, example)
x10[A] = 1, x10[B] = 0;
x11[A] = 1, x11[B] = 1;
// empty
DT empty;
// A
DT a(A, 0, 5);
LONGS_EQUAL(0,a(x00))
@ -106,6 +156,11 @@ TEST(DT, example)
LONGS_EQUAL(5,notb(x10))
DOT(notb);
// Check supplying empty trees yields an exception
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
// apply, two nodes, in natural order
DT anotb = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,anotb(x00))
@ -175,16 +230,37 @@ TEST(DT, example)
}
/* ******************************************************************************** */
// test Conversion
// test Conversion of values
std::function<bool(const int&)> bool_of_int = [](const int& y) {
return y != 0;
};
typedef DecisionTree<string, bool> StringBoolTree;
TEST(DT, ConvertValuesOnly)
{
// Create labels
string A("A"), B("B");
// apply, two nodes, in natural order
DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul);
// convert
StringBoolTree f2(f1, bool_of_int);
// Check a value
Assignment<string> x00;
x00["A"] = 0, x00["B"] = 0;
EXPECT(!f2(x00));
}
/* ******************************************************************************** */
// test Conversion of both values and labels.
enum Label {
U, V, X, Y, Z
};
typedef DecisionTree<Label, bool> BDT;
bool convert(const int& y) {
return y != 0;
}
typedef DecisionTree<Label, bool> LabelBoolTree;
TEST(DT, conversion)
TEST(DT, ConvertBoth)
{
// Create labels
string A("A"), B("B");
@ -196,12 +272,9 @@ TEST(DT, conversion)
map<string, Label> ordering;
ordering[A] = X;
ordering[B] = Y;
std::function<bool(const int&)> op = convert;
BDT f2(f1, ordering, op);
// f1.print("f1");
// f2.print("f2");
LabelBoolTree f2(f1, ordering, bool_of_int);
// create a value
// Check some values
Assignment<Label> x00, x01, x10, x11;
x00[X] = 0, x00[Y] = 0;
x01[X] = 0, x01[Y] = 1;