Fix most lint errors
parent
6aeb3db8f6
commit
d0ff3ab97e
|
@ -21,42 +21,44 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/assign/std/vector.hpp>
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
#include <boost/noncopyable.hpp>
|
#include <boost/noncopyable.hpp>
|
||||||
#include <boost/optional.hpp>
|
#include <boost/optional.hpp>
|
||||||
#include <boost/tuple/tuple.hpp>
|
#include <boost/tuple/tuple.hpp>
|
||||||
#include <boost/type_traits/has_dereference.hpp>
|
#include <boost/type_traits/has_dereference.hpp>
|
||||||
#include <boost/unordered_set.hpp>
|
#include <boost/unordered_set.hpp>
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
using boost::assign::operator+=;
|
using boost::assign::operator+=;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Node
|
// Node
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
int DecisionTree<L, Y>::Node::nrNodes = 0;
|
int DecisionTree<L, Y>::Node::nrNodes = 0;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Leaf
|
// Leaf
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node {
|
struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
|
||||||
|
|
||||||
/** constant stored in this leaf */
|
/** constant stored in this leaf */
|
||||||
Y constant_;
|
Y constant_;
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
/** Constructor from constant */
|
/** Constructor from constant */
|
||||||
Leaf(const Y& constant) :
|
Leaf(const Y& constant) :
|
||||||
constant_(constant) {}
|
constant_(constant) {}
|
||||||
|
@ -96,7 +98,7 @@ namespace gtsam {
|
||||||
std::string value = valueFormatter(constant_);
|
std::string value = valueFormatter(constant_);
|
||||||
if (showZero || value.compare("0"))
|
if (showZero || value.compare("0"))
|
||||||
os << "\"" << this->id() << "\" [label=\"" << value
|
os << "\"" << this->id() << "\" [label=\"" << value
|
||||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
|
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** evaluate */
|
/** evaluate */
|
||||||
|
@ -136,15 +138,13 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isLeaf() const override { return true; }
|
bool isLeaf() const override { return true; }
|
||||||
|
|
||||||
}; // Leaf
|
}; // Leaf
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Choice
|
// Choice
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
||||||
|
|
||||||
/** the label of the variable on which we split */
|
/** the label of the variable on which we split */
|
||||||
L label_;
|
L label_;
|
||||||
|
|
||||||
|
@ -158,10 +158,10 @@ namespace gtsam {
|
||||||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
~Choice() override {
|
~Choice() override {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl;
|
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||||
|
<< std::std::endl;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,7 +172,8 @@ namespace gtsam {
|
||||||
assert(f->branches().size() > 0);
|
assert(f->branches().size() > 0);
|
||||||
NodePtr f0 = f->branches_[0];
|
NodePtr f0 = f->branches_[0];
|
||||||
assert(f0->isLeaf());
|
assert(f0->isLeaf());
|
||||||
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
NodePtr newLeaf(
|
||||||
|
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||||
return newLeaf;
|
return newLeaf;
|
||||||
} else
|
} else
|
||||||
#endif
|
#endif
|
||||||
|
@ -192,7 +193,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
||||||
allSame_(true) {
|
allSame_(true) {
|
||||||
|
|
||||||
// Choose what to do based on label
|
// Choose what to do based on label
|
||||||
if (f.label() > g.label()) {
|
if (f.label() > g.label()) {
|
||||||
// f higher than g
|
// f higher than g
|
||||||
|
@ -318,10 +318,8 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
Choice(const L& label, const Choice& f, const Unary& op) :
|
Choice(const L& label, const Choice& f, const Unary& op) :
|
||||||
label_(label), allSame_(true) {
|
label_(label), allSame_(true) {
|
||||||
|
|
||||||
branches_.reserve(f.branches_.size()); // reserve space
|
branches_.reserve(f.branches_.size()); // reserve space
|
||||||
for (const NodePtr& branch: f.branches_)
|
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
|
||||||
push_back(branch->apply(op));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** apply unary operator */
|
/** apply unary operator */
|
||||||
|
@ -364,8 +362,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** choose a branch, recursively */
|
/** choose a branch, recursively */
|
||||||
NodePtr choose(const L& label, size_t index) const override {
|
NodePtr choose(const L& label, size_t index) const override {
|
||||||
if (label_ == label)
|
if (label_ == label) return branches_[index]; // choose branch
|
||||||
return branches_[index]; // choose branch
|
|
||||||
|
|
||||||
// second case, not label of interest, just recurse
|
// second case, not label of interest, just recurse
|
||||||
auto r = boost::make_shared<Choice>(label_, branches_.size());
|
auto r = boost::make_shared<Choice>(label_, branches_.size());
|
||||||
|
@ -373,12 +370,11 @@ namespace gtsam {
|
||||||
r->push_back(branch->choose(label, index));
|
r->push_back(branch->choose(label, index));
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
}; // Choice
|
}; // Choice
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// DecisionTree
|
// DecisionTree
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree() {
|
DecisionTree<L, Y>::DecisionTree() {
|
||||||
}
|
}
|
||||||
|
@ -388,13 +384,13 @@ namespace gtsam {
|
||||||
root_(root) {
|
root_(root) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const Y& y) {
|
DecisionTree<L, Y>::DecisionTree(const Y& y) {
|
||||||
root_ = NodePtr(new Leaf(y));
|
root_ = NodePtr(new Leaf(y));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
||||||
auto a = boost::make_shared<Choice>(label, 2);
|
auto a = boost::make_shared<Choice>(label, 2);
|
||||||
|
@ -404,7 +400,7 @@ namespace gtsam {
|
||||||
root_ = Choice::Unique(a);
|
root_ = Choice::Unique(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
|
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
|
||||||
const Y& y2) {
|
const Y& y2) {
|
||||||
|
@ -417,7 +413,7 @@ namespace gtsam {
|
||||||
root_ = Choice::Unique(a);
|
root_ = Choice::Unique(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||||
const std::vector<Y>& ys) {
|
const std::vector<Y>& ys) {
|
||||||
|
@ -425,11 +421,10 @@ namespace gtsam {
|
||||||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||||
const std::string& table) {
|
const std::string& table) {
|
||||||
|
|
||||||
// Convert std::string to values of type Y
|
// Convert std::string to values of type Y
|
||||||
std::vector<Y> ys;
|
std::vector<Y> ys;
|
||||||
std::istringstream iss(table);
|
std::istringstream iss(table);
|
||||||
|
@ -440,14 +435,14 @@ namespace gtsam {
|
||||||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
||||||
Iterator begin, Iterator end, const L& label) {
|
Iterator begin, Iterator end, const L& label) {
|
||||||
root_ = compose(begin, end, label);
|
root_ = compose(begin, end, label);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const L& label,
|
DecisionTree<L, Y>::DecisionTree(const L& label,
|
||||||
const DecisionTree& f0, const DecisionTree& f1) {
|
const DecisionTree& f0, const DecisionTree& f1) {
|
||||||
|
@ -456,7 +451,7 @@ namespace gtsam {
|
||||||
root_ = compose(functions.begin(), functions.end(), label);
|
root_ = compose(functions.begin(), functions.end(), label);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename X, typename Func>
|
template <typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||||
|
@ -466,7 +461,7 @@ namespace gtsam {
|
||||||
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X, typename Func>
|
template <typename M, typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
||||||
|
@ -475,16 +470,16 @@ namespace gtsam {
|
||||||
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Called by two constructors above.
|
// Called by two constructors above.
|
||||||
// Takes a label and a corresponding range of decision trees, and creates a new
|
// Takes a label and a corresponding range of decision trees, and creates a
|
||||||
// decision tree. However, the order of the labels needs to be respected, so we
|
// new decision tree. However, the order of the labels needs to be respected,
|
||||||
// cannot just create a root Choice node on the label: if the label is not the
|
// so we cannot just create a root Choice node on the label: if the label is
|
||||||
// highest label, we need to do a complicated and expensive recursive call.
|
// not the highest label, we need a complicated/ expensive recursive call.
|
||||||
template<typename L, typename Y> template<typename Iterator>
|
template <typename L, typename Y>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin,
|
template <typename Iterator>
|
||||||
Iterator end, const L& label) const {
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
|
||||||
|
Iterator begin, Iterator end, const L& label) const {
|
||||||
// find highest label among branches
|
// find highest label among branches
|
||||||
boost::optional<L> highestLabel;
|
boost::optional<L> highestLabel;
|
||||||
size_t nrChoices = 0;
|
size_t nrChoices = 0;
|
||||||
|
@ -527,7 +522,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// "create" is a bit of a complicated thing, but very useful.
|
// "create" is a bit of a complicated thing, but very useful.
|
||||||
// It takes a range of labels and a corresponding range of values,
|
// It takes a range of labels and a corresponding range of values,
|
||||||
// and creates a decision tree, as follows:
|
// and creates a decision tree, as follows:
|
||||||
|
@ -552,7 +547,6 @@ namespace gtsam {
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||||
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
||||||
|
|
||||||
// get crucial counts
|
// get crucial counts
|
||||||
size_t nrChoices = begin->second;
|
size_t nrChoices = begin->second;
|
||||||
size_t size = endY - beginY;
|
size_t size = endY - beginY;
|
||||||
|
@ -564,7 +558,11 @@ namespace gtsam {
|
||||||
// Create a simple choice node with values as leaves.
|
// Create a simple choice node with values as leaves.
|
||||||
if (size != nrChoices) {
|
if (size != nrChoices) {
|
||||||
std::cout << "Trying to create DD on " << begin->first << std::endl;
|
std::cout << "Trying to create DD on " << begin->first << std::endl;
|
||||||
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
|
std::cout << boost::format(
|
||||||
|
"DecisionTree::create: expected %d values but got %d "
|
||||||
|
"instead") %
|
||||||
|
nrChoices % size
|
||||||
|
<< std::endl;
|
||||||
throw std::invalid_argument("DecisionTree::create invalid argument");
|
throw std::invalid_argument("DecisionTree::create invalid argument");
|
||||||
}
|
}
|
||||||
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
|
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
|
||||||
|
@ -585,7 +583,7 @@ namespace gtsam {
|
||||||
return compose(functions.begin(), functions.end(), begin->first);
|
return compose(functions.begin(), functions.end(), begin->first);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
|
@ -594,8 +592,8 @@ namespace gtsam {
|
||||||
std::function<Y(const X&)> Y_of_X) const {
|
std::function<Y(const X&)> Y_of_X) const {
|
||||||
using LY = DecisionTree<L, Y>;
|
using LY = DecisionTree<L, Y>;
|
||||||
|
|
||||||
// ugliness below because apparently we can't have templated virtual functions
|
// ugliness below because apparently we can't have templated virtual
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf
|
// functions If leaf, apply unary conversion "op" and create a unique leaf
|
||||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||||
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
||||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
|
@ -618,12 +616,12 @@ namespace gtsam {
|
||||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Functor performing depth-first visit without Assignment<L> argument.
|
// Functor performing depth-first visit without Assignment<L> argument.
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
struct Visit {
|
struct Visit {
|
||||||
using F = std::function<void(const Y&)>;
|
using F = std::function<void(const Y&)>;
|
||||||
Visit(F f) : f(f) {} ///< Construct from folding function.
|
explicit Visit(F f) : f(f) {} ///< Construct from folding function.
|
||||||
F f; ///< folding function object.
|
F f; ///< folding function object.
|
||||||
|
|
||||||
/// Do a depth-first visit on the tree rooted at node.
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
|
@ -647,13 +645,13 @@ namespace gtsam {
|
||||||
visit(root_);
|
visit(root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Functor performing depth-first visit with Assignment<L> argument.
|
// Functor performing depth-first visit with Assignment<L> argument.
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
struct VisitWith {
|
struct VisitWith {
|
||||||
using Choices = Assignment<L>;
|
using Choices = Assignment<L>;
|
||||||
using F = std::function<void(const Choices&, const Y&)>;
|
using F = std::function<void(const Choices&, const Y&)>;
|
||||||
VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
||||||
Choices choices; ///< Assignment, mutating through recursion.
|
Choices choices; ///< Assignment, mutating through recursion.
|
||||||
F f; ///< folding function object.
|
F f; ///< folding function object.
|
||||||
|
|
||||||
|
@ -681,7 +679,7 @@ namespace gtsam {
|
||||||
visit(root_);
|
visit(root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// fold is just done with a visit
|
// fold is just done with a visit
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename Func, typename X>
|
template <typename Func, typename X>
|
||||||
|
@ -690,7 +688,7 @@ namespace gtsam {
|
||||||
return x0;
|
return x0;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// labels is just done with a visit
|
// labels is just done with a visit
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
std::set<L> DecisionTree<L, Y>::labels() const {
|
std::set<L> DecisionTree<L, Y>::labels() const {
|
||||||
|
@ -702,7 +700,7 @@ namespace gtsam {
|
||||||
return unique;
|
return unique;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||||
const CompareFunc& compare) const {
|
const CompareFunc& compare) const {
|
||||||
|
@ -736,7 +734,7 @@ namespace gtsam {
|
||||||
return DecisionTree(root_->apply(op));
|
return DecisionTree(root_->apply(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
||||||
const Binary& op) const {
|
const Binary& op) const {
|
||||||
|
@ -752,7 +750,7 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// The way this works:
|
// The way this works:
|
||||||
// We have an ADT, picture it as a tree.
|
// We have an ADT, picture it as a tree.
|
||||||
// At a certain depth, we have a branch on "label".
|
// At a certain depth, we have a branch on "label".
|
||||||
|
@ -772,7 +770,7 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
void DecisionTree<L, Y>::dot(std::ostream& os,
|
void DecisionTree<L, Y>::dot(std::ostream& os,
|
||||||
const LabelFormatter& labelFormatter,
|
const LabelFormatter& labelFormatter,
|
||||||
|
@ -790,9 +788,11 @@ namespace gtsam {
|
||||||
bool showZero) const {
|
bool showZero) const {
|
||||||
std::ofstream os((name + ".dot").c_str());
|
std::ofstream os((name + ".dot").c_str());
|
||||||
dot(os, labelFormatter, valueFormatter, showZero);
|
dot(os, labelFormatter, valueFormatter, showZero);
|
||||||
int result = system(
|
int result =
|
||||||
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
|
system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
|
||||||
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
|
.c_str());
|
||||||
|
if (result == -1)
|
||||||
|
throw std::runtime_error("DecisionTree::dot system call failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
|
@ -804,8 +804,6 @@ namespace gtsam {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/******************************************************************************/
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -26,9 +26,11 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <sstream>
|
|
||||||
#include <vector>
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -39,7 +41,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class DecisionTree {
|
class DecisionTree {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// Default method for comparison of two objects of type Y.
|
/// Default method for comparison of two objects of type Y.
|
||||||
static bool DefaultCompare(const Y& a, const Y& b) {
|
static bool DefaultCompare(const Y& a, const Y& b) {
|
||||||
|
@ -47,7 +48,6 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
using LabelFormatter = std::function<std::string(L)>;
|
using LabelFormatter = std::function<std::string(L)>;
|
||||||
using ValueFormatter = std::function<std::string(Y)>;
|
using ValueFormatter = std::function<std::string(Y)>;
|
||||||
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
||||||
|
@ -60,12 +60,11 @@ namespace gtsam {
|
||||||
using LabelC = std::pair<L, size_t>;
|
using LabelC = std::pair<L, size_t>;
|
||||||
|
|
||||||
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
||||||
class Leaf;
|
struct Leaf;
|
||||||
class Choice;
|
struct Choice;
|
||||||
|
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
class Node {
|
struct Node {
|
||||||
public:
|
|
||||||
using Ptr = boost::shared_ptr<const Node>;
|
using Ptr = boost::shared_ptr<const Node>;
|
||||||
|
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
|
@ -75,14 +74,16 @@ namespace gtsam {
|
||||||
// Constructor
|
// Constructor
|
||||||
Node() {
|
Node() {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush();
|
std::cout << ++nrNodes << " constructed " << id() << std::endl;
|
||||||
|
std::cout.flush();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destructor
|
// Destructor
|
||||||
virtual ~Node() {
|
virtual ~Node() {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush();
|
std::cout << --nrNodes << " destructed " << id() << std::endl;
|
||||||
|
std::cout.flush();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,7 +112,6 @@ namespace gtsam {
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/** A function is a shared pointer to the root of a DT */
|
/** A function is a shared pointer to the root of a DT */
|
||||||
using NodePtr = typename Node::Ptr;
|
using NodePtr = typename Node::Ptr;
|
||||||
|
|
||||||
|
@ -119,8 +119,9 @@ namespace gtsam {
|
||||||
NodePtr root_;
|
NodePtr root_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
/** Internal recursive function to create from keys, cardinalities,
|
||||||
/** Internal recursive function to create from keys, cardinalities, and Y values */
|
* and Y values
|
||||||
|
*/
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||||
|
|
||||||
|
@ -140,7 +141,6 @@ namespace gtsam {
|
||||||
std::function<Y(const X&)> Y_of_X) const;
|
std::function<Y(const X&)> Y_of_X) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
@ -148,7 +148,7 @@ namespace gtsam {
|
||||||
DecisionTree();
|
DecisionTree();
|
||||||
|
|
||||||
/** Create a constant */
|
/** Create a constant */
|
||||||
DecisionTree(const Y& y);
|
explicit DecisionTree(const Y& y);
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
||||||
|
@ -167,8 +167,8 @@ namespace gtsam {
|
||||||
DecisionTree(Iterator begin, Iterator end, const L& label);
|
DecisionTree(Iterator begin, Iterator end, const L& label);
|
||||||
|
|
||||||
/** Create DecisionTree from two others */
|
/** Create DecisionTree from two others */
|
||||||
DecisionTree(const L& label, //
|
DecisionTree(const L& label, const DecisionTree& f0,
|
||||||
const DecisionTree& f0, const DecisionTree& f1);
|
const DecisionTree& f1);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a different value type.
|
* @brief Convert from a different value type.
|
||||||
|
@ -289,7 +289,8 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** combine subtrees on key with binary operation "op" */
|
/** combine subtrees on key with binary operation "op" */
|
||||||
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const;
|
DecisionTree combine(const L& label, size_t cardinality,
|
||||||
|
const Binary& op) const;
|
||||||
|
|
||||||
/** combine with LabelC for convenience */
|
/** combine with LabelC for convenience */
|
||||||
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
|
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
|
||||||
|
@ -313,14 +314,13 @@ namespace gtsam {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
// internal use only
|
// internal use only
|
||||||
DecisionTree(const NodePtr& root);
|
explicit DecisionTree(const NodePtr& root);
|
||||||
|
|
||||||
// internal use only
|
// internal use only
|
||||||
template<typename Iterator> NodePtr
|
template<typename Iterator> NodePtr
|
||||||
compose(Iterator begin, Iterator end, const L& label) const;
|
compose(Iterator begin, Iterator end, const L& label) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
}; // DecisionTree
|
}; // DecisionTree
|
||||||
|
|
||||||
/** free versions of apply */
|
/** free versions of apply */
|
||||||
|
@ -340,11 +340,19 @@ namespace gtsam {
|
||||||
return f.apply(g, op);
|
return f.apply(g, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// unzip a DecisionTree if its leaves are `std::pair`
|
/**
|
||||||
|
* @brief unzip a DecisionTree with `std::pair` values.
|
||||||
|
*
|
||||||
|
* @param input the DecisionTree with `(T1,T2)` values.
|
||||||
|
* @return a pair of DecisionTree on T1 and T2, respectively.
|
||||||
|
*/
|
||||||
template <typename L, typename T1, typename T2>
|
template <typename L, typename T1, typename T2>
|
||||||
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(const DecisionTree<L, std::pair<T1, T2> > &input) {
|
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
|
||||||
return std::make_pair(DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
|
const DecisionTree<L, std::pair<T1, T2> >& input) {
|
||||||
DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; }));
|
return std::make_pair(
|
||||||
|
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
|
||||||
|
DecisionTree<L, T2>(input,
|
||||||
|
[](std::pair<T1, T2> i) { return i.second; }));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
Loading…
Reference in New Issue