Fix most lint errors

release/4.3a0
Frank Dellaert 2022-01-22 12:40:29 -05:00
parent 6aeb3db8f6
commit d0ff3ab97e
2 changed files with 121 additions and 115 deletions

View File

@ -21,42 +21,44 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <algorithm>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/format.hpp> #include <boost/format.hpp>
#include <boost/make_shared.hpp>
#include <boost/noncopyable.hpp> #include <boost/noncopyable.hpp>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <boost/type_traits/has_dereference.hpp> #include <boost/type_traits/has_dereference.hpp>
#include <boost/unordered_set.hpp> #include <boost/unordered_set.hpp>
#include <boost/make_shared.hpp>
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
#include <list> #include <list>
#include <map>
#include <set>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
using boost::assign::operator+=; using boost::assign::operator+=;
namespace gtsam { namespace gtsam {
/*********************************************************************************/ /****************************************************************************/
// Node // Node
/*********************************************************************************/ /****************************************************************************/
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
template<typename L, typename Y> template<typename L, typename Y>
int DecisionTree<L, Y>::Node::nrNodes = 0; int DecisionTree<L, Y>::Node::nrNodes = 0;
#endif #endif
/*********************************************************************************/ /****************************************************************************/
// Leaf // Leaf
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template <typename L, typename Y>
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node { struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
/** constant stored in this leaf */ /** constant stored in this leaf */
Y constant_; Y constant_;
public:
/** Constructor from constant */ /** Constructor from constant */
Leaf(const Y& constant) : Leaf(const Y& constant) :
constant_(constant) {} constant_(constant) {}
@ -96,7 +98,7 @@ namespace gtsam {
std::string value = valueFormatter(constant_); std::string value = valueFormatter(constant_);
if (showZero || value.compare("0")) if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
} }
/** evaluate */ /** evaluate */
@ -136,15 +138,13 @@ namespace gtsam {
} }
bool isLeaf() const override { return true; } bool isLeaf() const override { return true; }
}; // Leaf }; // Leaf
/*********************************************************************************/ /****************************************************************************/
// Choice // Choice
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node { struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
/** the label of the variable on which we split */ /** the label of the variable on which we split */
L label_; L label_;
@ -158,10 +158,10 @@ namespace gtsam {
using ChoicePtr = boost::shared_ptr<const Choice>; using ChoicePtr = boost::shared_ptr<const Choice>;
public: public:
~Choice() override { ~Choice() override {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl; std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
<< std::std::endl;
#endif #endif
} }
@ -172,7 +172,8 @@ namespace gtsam {
assert(f->branches().size() > 0); assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0]; NodePtr f0 = f->branches_[0];
assert(f0->isLeaf()); assert(f0->isLeaf());
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant())); NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
return newLeaf; return newLeaf;
} else } else
#endif #endif
@ -192,7 +193,6 @@ namespace gtsam {
*/ */
Choice(const Choice& f, const Choice& g, const Binary& op) : Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) { allSame_(true) {
// Choose what to do based on label // Choose what to do based on label
if (f.label() > g.label()) { if (f.label() > g.label()) {
// f higher than g // f higher than g
@ -318,10 +318,8 @@ namespace gtsam {
*/ */
Choice(const L& label, const Choice& f, const Unary& op) : Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) { label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space branches_.reserve(f.branches_.size()); // reserve space
for (const NodePtr& branch: f.branches_) for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
push_back(branch->apply(op));
} }
/** apply unary operator */ /** apply unary operator */
@ -364,8 +362,7 @@ namespace gtsam {
/** choose a branch, recursively */ /** choose a branch, recursively */
NodePtr choose(const L& label, size_t index) const override { NodePtr choose(const L& label, size_t index) const override {
if (label_ == label) if (label_ == label) return branches_[index]; // choose branch
return branches_[index]; // choose branch
// second case, not label of interest, just recurse // second case, not label of interest, just recurse
auto r = boost::make_shared<Choice>(label_, branches_.size()); auto r = boost::make_shared<Choice>(label_, branches_.size());
@ -373,12 +370,11 @@ namespace gtsam {
r->push_back(branch->choose(label, index)); r->push_back(branch->choose(label, index));
return Unique(r); return Unique(r);
} }
}; // Choice }; // Choice
/*********************************************************************************/ /****************************************************************************/
// DecisionTree // DecisionTree
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() { DecisionTree<L, Y>::DecisionTree() {
} }
@ -388,13 +384,13 @@ namespace gtsam {
root_(root) { root_(root) {
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const Y& y) { DecisionTree<L, Y>::DecisionTree(const Y& y) {
root_ = NodePtr(new Leaf(y)); root_ = NodePtr(new Leaf(y));
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) { DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
auto a = boost::make_shared<Choice>(label, 2); auto a = boost::make_shared<Choice>(label, 2);
@ -404,7 +400,7 @@ namespace gtsam {
root_ = Choice::Unique(a); root_ = Choice::Unique(a);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1, DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
const Y& y2) { const Y& y2) {
@ -417,7 +413,7 @@ namespace gtsam {
root_ = Choice::Unique(a); root_ = Choice::Unique(a);
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::vector<Y>& ys) { const std::vector<Y>& ys) {
@ -425,11 +421,10 @@ namespace gtsam {
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::string& table) { const std::string& table) {
// Convert std::string to values of type Y // Convert std::string to values of type Y
std::vector<Y> ys; std::vector<Y> ys;
std::istringstream iss(table); std::istringstream iss(table);
@ -440,14 +435,14 @@ namespace gtsam {
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
template<typename Iterator> DecisionTree<L, Y>::DecisionTree( template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
Iterator begin, Iterator end, const L& label) { Iterator begin, Iterator end, const L& label) {
root_ = compose(begin, end, label); root_ = compose(begin, end, label);
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, DecisionTree<L, Y>::DecisionTree(const L& label,
const DecisionTree& f0, const DecisionTree& f1) { const DecisionTree& f0, const DecisionTree& f1) {
@ -456,7 +451,7 @@ namespace gtsam {
root_ = compose(functions.begin(), functions.end(), label); root_ = compose(functions.begin(), functions.end(), label);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename X, typename Func> template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
@ -466,7 +461,7 @@ namespace gtsam {
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X); root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X, typename Func> template <typename M, typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
@ -475,16 +470,16 @@ namespace gtsam {
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X); root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
} }
/*********************************************************************************/ /****************************************************************************/
// Called by two constructors above. // Called by two constructors above.
// Takes a label and a corresponding range of decision trees, and creates a new // Takes a label and a corresponding range of decision trees, and creates a
// decision tree. However, the order of the labels needs to be respected, so we // new decision tree. However, the order of the labels needs to be respected,
// cannot just create a root Choice node on the label: if the label is not the // so we cannot just create a root Choice node on the label: if the label is
// highest label, we need to do a complicated and expensive recursive call. // not the highest label, we need a complicated/ expensive recursive call.
template<typename L, typename Y> template<typename Iterator> template <typename L, typename Y>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin, template <typename Iterator>
Iterator end, const L& label) const { typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
Iterator begin, Iterator end, const L& label) const {
// find highest label among branches // find highest label among branches
boost::optional<L> highestLabel; boost::optional<L> highestLabel;
size_t nrChoices = 0; size_t nrChoices = 0;
@ -527,7 +522,7 @@ namespace gtsam {
} }
} }
/*********************************************************************************/ /****************************************************************************/
// "create" is a bit of a complicated thing, but very useful. // "create" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values, // It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows: // and creates a decision tree, as follows:
@ -552,7 +547,6 @@ namespace gtsam {
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const { It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts // get crucial counts
size_t nrChoices = begin->second; size_t nrChoices = begin->second;
size_t size = endY - beginY; size_t size = endY - beginY;
@ -564,7 +558,11 @@ namespace gtsam {
// Create a simple choice node with values as leaves. // Create a simple choice node with values as leaves.
if (size != nrChoices) { if (size != nrChoices) {
std::cout << "Trying to create DD on " << begin->first << std::endl; std::cout << "Trying to create DD on " << begin->first << std::endl;
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; std::cout << boost::format(
"DecisionTree::create: expected %d values but got %d "
"instead") %
nrChoices % size
<< std::endl;
throw std::invalid_argument("DecisionTree::create invalid argument"); throw std::invalid_argument("DecisionTree::create invalid argument");
} }
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY); auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
@ -585,7 +583,7 @@ namespace gtsam {
return compose(functions.begin(), functions.end(), begin->first); return compose(functions.begin(), functions.end(), begin->first);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X> template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
@ -594,8 +592,8 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const { std::function<Y(const X&)> Y_of_X) const {
using LY = DecisionTree<L, Y>; using LY = DecisionTree<L, Y>;
// ugliness below because apparently we can't have templated virtual functions // ugliness below because apparently we can't have templated virtual
// If leaf, apply unary conversion "op" and create a unique leaf // functions If leaf, apply unary conversion "op" and create a unique leaf
using MXLeaf = typename DecisionTree<M, X>::Leaf; using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
return NodePtr(new Leaf(Y_of_X(leaf->constant()))); return NodePtr(new Leaf(Y_of_X(leaf->constant())));
@ -612,18 +610,18 @@ namespace gtsam {
// put together via Shannon expansion otherwise not sorted. // put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions; std::vector<LY> functions;
for(auto && branch: choice->branches()) { for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X)); functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
} }
return LY::compose(functions.begin(), functions.end(), newLabel); return LY::compose(functions.begin(), functions.end(), newLabel);
} }
/*********************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument. // Functor performing depth-first visit without Assignment<L> argument.
template <typename L, typename Y> template <typename L, typename Y>
struct Visit { struct Visit {
using F = std::function<void(const Y&)>; using F = std::function<void(const Y&)>;
Visit(F f) : f(f) {} ///< Construct from folding function. explicit Visit(F f) : f(f) {} ///< Construct from folding function.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
@ -647,13 +645,13 @@ namespace gtsam {
visit(root_); visit(root_);
} }
/*********************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument. // Functor performing depth-first visit with Assignment<L> argument.
template <typename L, typename Y> template <typename L, typename Y>
struct VisitWith { struct VisitWith {
using Choices = Assignment<L>; using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>; using F = std::function<void(const Choices&, const Y&)>;
VisitWith(F f) : f(f) {} ///< Construct from folding function. explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion. Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object. F f; ///< folding function object.
@ -681,7 +679,7 @@ namespace gtsam {
visit(root_); visit(root_);
} }
/*********************************************************************************/ /****************************************************************************/
// fold is just done with a visit // fold is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>
template <typename Func, typename X> template <typename Func, typename X>
@ -690,7 +688,7 @@ namespace gtsam {
return x0; return x0;
} }
/*********************************************************************************/ /****************************************************************************/
// labels is just done with a visit // labels is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const { std::set<L> DecisionTree<L, Y>::labels() const {
@ -702,7 +700,7 @@ namespace gtsam {
return unique; return unique;
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other, bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const { const CompareFunc& compare) const {
@ -736,7 +734,7 @@ namespace gtsam {
return DecisionTree(root_->apply(op)); return DecisionTree(root_->apply(op));
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g, DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
const Binary& op) const { const Binary& op) const {
@ -752,7 +750,7 @@ namespace gtsam {
return result; return result;
} }
/*********************************************************************************/ /****************************************************************************/
// The way this works: // The way this works:
// We have an ADT, picture it as a tree. // We have an ADT, picture it as a tree.
// At a certain depth, we have a branch on "label". // At a certain depth, we have a branch on "label".
@ -772,7 +770,7 @@ namespace gtsam {
return result; return result;
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os, void DecisionTree<L, Y>::dot(std::ostream& os,
const LabelFormatter& labelFormatter, const LabelFormatter& labelFormatter,
@ -790,9 +788,11 @@ namespace gtsam {
bool showZero) const { bool showZero) const {
std::ofstream os((name + ".dot").c_str()); std::ofstream os((name + ".dot").c_str());
dot(os, labelFormatter, valueFormatter, showZero); dot(os, labelFormatter, valueFormatter, showZero);
int result = system( int result =
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); .c_str());
if (result == -1)
throw std::runtime_error("DecisionTree::dot system call failed");
} }
template <typename L, typename Y> template <typename L, typename Y>
@ -804,8 +804,6 @@ namespace gtsam {
return ss.str(); return ss.str();
} }
/*********************************************************************************/ /******************************************************************************/
} // namespace gtsam
} // namespace gtsam

View File

@ -26,9 +26,11 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <sstream>
#include <vector>
#include <set> #include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace gtsam { namespace gtsam {
@ -39,7 +41,6 @@ namespace gtsam {
*/ */
template<typename L, typename Y> template<typename L, typename Y>
class DecisionTree { class DecisionTree {
protected: protected:
/// Default method for comparison of two objects of type Y. /// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) { static bool DefaultCompare(const Y& a, const Y& b) {
@ -47,7 +48,6 @@ namespace gtsam {
} }
public: public:
using LabelFormatter = std::function<std::string(L)>; using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>; using ValueFormatter = std::function<std::string(Y)>;
using CompareFunc = std::function<bool(const Y&, const Y&)>; using CompareFunc = std::function<bool(const Y&, const Y&)>;
@ -57,15 +57,14 @@ namespace gtsam {
using Binary = std::function<Y(const Y&, const Y&)>; using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */ /** A label annotated with cardinality */
using LabelC = std::pair<L,size_t>; using LabelC = std::pair<L, size_t>;
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */ /** DTs consist of Leaf and Choice nodes, both subclasses of Node */
class Leaf; struct Leaf;
class Choice; struct Choice;
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
class Node { struct Node {
public:
using Ptr = boost::shared_ptr<const Node>; using Ptr = boost::shared_ptr<const Node>;
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
@ -75,14 +74,16 @@ namespace gtsam {
// Constructor // Constructor
Node() { Node() {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush(); std::cout << ++nrNodes << " constructed " << id() << std::endl;
std::cout.flush();
#endif #endif
} }
// Destructor // Destructor
virtual ~Node() { virtual ~Node() {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush(); std::cout << --nrNodes << " destructed " << id() << std::endl;
std::cout.flush();
#endif #endif
} }
@ -111,7 +112,6 @@ namespace gtsam {
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
public: public:
/** A function is a shared pointer to the root of a DT */ /** A function is a shared pointer to the root of a DT */
using NodePtr = typename Node::Ptr; using NodePtr = typename Node::Ptr;
@ -119,8 +119,9 @@ namespace gtsam {
NodePtr root_; NodePtr root_;
protected: protected:
/** Internal recursive function to create from keys, cardinalities,
/** Internal recursive function to create from keys, cardinalities, and Y values */ * and Y values
*/
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
@ -140,7 +141,6 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const; std::function<Y(const X&)> Y_of_X) const;
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -148,7 +148,7 @@ namespace gtsam {
DecisionTree(); DecisionTree();
/** Create a constant */ /** Create a constant */
DecisionTree(const Y& y); explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
DecisionTree(const L& label, const Y& y1, const Y& y2); DecisionTree(const L& label, const Y& y1, const Y& y2);
@ -167,8 +167,8 @@ namespace gtsam {
DecisionTree(Iterator begin, Iterator end, const L& label); DecisionTree(Iterator begin, Iterator end, const L& label);
/** Create DecisionTree from two others */ /** Create DecisionTree from two others */
DecisionTree(const L& label, // DecisionTree(const L& label, const DecisionTree& f0,
const DecisionTree& f0, const DecisionTree& f1); const DecisionTree& f1);
/** /**
* @brief Convert from a different value type. * @brief Convert from a different value type.
@ -289,7 +289,8 @@ namespace gtsam {
} }
/** combine subtrees on key with binary operation "op" */ /** combine subtrees on key with binary operation "op" */
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; DecisionTree combine(const L& label, size_t cardinality,
const Binary& op) const;
/** combine with LabelC for convenience */ /** combine with LabelC for convenience */
DecisionTree combine(const LabelC& labelC, const Binary& op) const { DecisionTree combine(const LabelC& labelC, const Binary& op) const {
@ -313,14 +314,13 @@ namespace gtsam {
/// @{ /// @{
// internal use only // internal use only
DecisionTree(const NodePtr& root); explicit DecisionTree(const NodePtr& root);
// internal use only // internal use only
template<typename Iterator> NodePtr template<typename Iterator> NodePtr
compose(Iterator begin, Iterator end, const L& label) const; compose(Iterator begin, Iterator end, const L& label) const;
/// @} /// @}
}; // DecisionTree }; // DecisionTree
/** free versions of apply */ /** free versions of apply */
@ -340,11 +340,19 @@ namespace gtsam {
return f.apply(g, op); return f.apply(g, op);
} }
/// unzip a DecisionTree if its leaves are `std::pair` /**
template<typename L, typename T1, typename T2> * @brief unzip a DecisionTree with `std::pair` values.
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(const DecisionTree<L, std::pair<T1, T2> > &input) { *
return std::make_pair(DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }), * @param input the DecisionTree with `(T1,T2)` values.
DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; })); * @return a pair of DecisionTree on T1 and T2, respectively.
*/
template <typename L, typename T1, typename T2>
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
const DecisionTree<L, std::pair<T1, T2> >& input) {
return std::make_pair(
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
DecisionTree<L, T2>(input,
[](std::pair<T1, T2> i) { return i.second; }));
} }
} // namespace gtsam } // namespace gtsam