Merge pull request #1005 from borglab/feature/better_decision_tree
commit
14ec0ae04b
|
|
@ -28,6 +28,7 @@
|
||||||
#include <boost/tuple/tuple.hpp>
|
#include <boost/tuple/tuple.hpp>
|
||||||
#include <boost/type_traits/has_dereference.hpp>
|
#include <boost/type_traits/has_dereference.hpp>
|
||||||
#include <boost/unordered_set.hpp>
|
#include <boost/unordered_set.hpp>
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
|
@ -82,13 +83,7 @@ namespace gtsam {
|
||||||
return compare(this->constant_, other->constant_);
|
return compare(this->constant_, other->constant_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** print */
|
||||||
* @brief Print method.
|
|
||||||
*
|
|
||||||
* @param s Prefix string.
|
|
||||||
* @param labelFormatter Functor to format the labels of type L.
|
|
||||||
* @param valueFormatter Functor to format the values of type Y.
|
|
||||||
*/
|
|
||||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
||||||
|
|
@ -332,7 +327,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** apply unary operator */
|
/** apply unary operator */
|
||||||
NodePtr apply(const Unary& op) const override {
|
NodePtr apply(const Unary& op) const override {
|
||||||
boost::shared_ptr<Choice> r(new Choice(label_, *this, op));
|
auto r = boost::make_shared<Choice>(label_, *this, op);
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -347,24 +342,24 @@ namespace gtsam {
|
||||||
|
|
||||||
// If second argument of binary op is Leaf node, recurse on branches
|
// If second argument of binary op is Leaf node, recurse on branches
|
||||||
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
|
||||||
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
|
auto h = boost::make_shared<Choice>(label(), nrChoices());
|
||||||
for(NodePtr branch: branches_)
|
for (auto&& branch : branches_)
|
||||||
h->push_back(fL.apply_f_op_g(*branch, op));
|
h->push_back(fL.apply_f_op_g(*branch, op));
|
||||||
return Unique(h);
|
return Unique(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If second argument of binary op is Choice, call constructor
|
// If second argument of binary op is Choice, call constructor
|
||||||
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
|
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
|
||||||
boost::shared_ptr<Choice> h(new Choice(fC, *this, op));
|
auto h = boost::make_shared<Choice>(fC, *this, op);
|
||||||
return Unique(h);
|
return Unique(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If second argument of binary op is Leaf
|
// If second argument of binary op is Leaf
|
||||||
template<typename OP>
|
template<typename OP>
|
||||||
NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
|
NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
|
||||||
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
|
auto h = boost::make_shared<Choice>(label(), nrChoices());
|
||||||
for(const NodePtr& branch: branches_)
|
for (auto&& branch : branches_)
|
||||||
h->push_back(branch->apply_f_op_g(gL, op));
|
h->push_back(branch->apply_f_op_g(gL, op));
|
||||||
return Unique(h);
|
return Unique(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -374,9 +369,9 @@ namespace gtsam {
|
||||||
return branches_[index]; // choose branch
|
return branches_[index]; // choose branch
|
||||||
|
|
||||||
// second case, not label of interest, just recurse
|
// second case, not label of interest, just recurse
|
||||||
boost::shared_ptr<Choice> r(new Choice(label_, branches_.size()));
|
auto r = boost::make_shared<Choice>(label_, branches_.size());
|
||||||
for(const NodePtr& branch: branches_)
|
for (auto&& branch : branches_)
|
||||||
r->push_back(branch->choose(label, index));
|
r->push_back(branch->choose(label, index));
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -401,10 +396,9 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(//
|
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
||||||
const L& label, const Y& y1, const Y& y2) {
|
auto a = boost::make_shared<Choice>(label, 2);
|
||||||
boost::shared_ptr<Choice> a(new Choice(label, 2));
|
|
||||||
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
||||||
a->push_back(l1);
|
a->push_back(l1);
|
||||||
a->push_back(l2);
|
a->push_back(l2);
|
||||||
|
|
@ -412,12 +406,12 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(//
|
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
|
||||||
const LabelC& labelC, const Y& y1, const Y& y2) {
|
const Y& y2) {
|
||||||
if (labelC.second != 2) throw std::invalid_argument(
|
if (labelC.second != 2) throw std::invalid_argument(
|
||||||
"DecisionTree: binary constructor called with non-binary label");
|
"DecisionTree: binary constructor called with non-binary label");
|
||||||
boost::shared_ptr<Choice> a(new Choice(labelC.first, 2));
|
auto a = boost::make_shared<Choice>(labelC.first, 2);
|
||||||
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
||||||
a->push_back(l1);
|
a->push_back(l1);
|
||||||
a->push_back(l2);
|
a->push_back(l2);
|
||||||
|
|
@ -465,23 +459,20 @@ namespace gtsam {
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename X>
|
template <typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||||
std::function<Y(const X&)> Y_of_X) {
|
Func Y_of_X) {
|
||||||
// Define functor for identity mapping of node label.
|
// Define functor for identity mapping of node label.
|
||||||
auto L_of_L = [](const L& label) { return label; };
|
auto L_of_L = [](const L& label) { return label; };
|
||||||
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X>
|
template <typename M, typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
||||||
const std::map<M, L>& map,
|
const std::map<M, L>& map, Func Y_of_X) {
|
||||||
std::function<Y(const X&)> Y_of_X) {
|
auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
|
||||||
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
|
|
||||||
return map.at(label);
|
|
||||||
};
|
|
||||||
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -511,13 +502,14 @@ namespace gtsam {
|
||||||
|
|
||||||
// if label is already in correct order, just put together a choice on label
|
// if label is already in correct order, just put together a choice on label
|
||||||
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
||||||
boost::shared_ptr<Choice> choiceOnLabel(new Choice(label, end - begin));
|
auto choiceOnLabel = boost::make_shared<Choice>(label, end - begin);
|
||||||
for (Iterator it = begin; it != end; it++)
|
for (Iterator it = begin; it != end; it++)
|
||||||
choiceOnLabel->push_back(it->root_);
|
choiceOnLabel->push_back(it->root_);
|
||||||
return Choice::Unique(choiceOnLabel);
|
return Choice::Unique(choiceOnLabel);
|
||||||
} else {
|
} else {
|
||||||
// Set up a new choice on the highest label
|
// Set up a new choice on the highest label
|
||||||
boost::shared_ptr<Choice> choiceOnHighestLabel(new Choice(*highestLabel, nrChoices));
|
auto choiceOnHighestLabel =
|
||||||
|
boost::make_shared<Choice>(*highestLabel, nrChoices);
|
||||||
// now, for all possible values of highestLabel
|
// now, for all possible values of highestLabel
|
||||||
for (size_t index = 0; index < nrChoices; index++) {
|
for (size_t index = 0; index < nrChoices; index++) {
|
||||||
// make a new set of functions for composing by iterating over the given
|
// make a new set of functions for composing by iterating over the given
|
||||||
|
|
@ -576,7 +568,7 @@ namespace gtsam {
|
||||||
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
|
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
|
||||||
throw std::invalid_argument("DecisionTree::create invalid argument");
|
throw std::invalid_argument("DecisionTree::create invalid argument");
|
||||||
}
|
}
|
||||||
boost::shared_ptr<Choice> choice(new Choice(begin->first, endY - beginY));
|
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
|
||||||
for (ValueIt y = beginY; y != endY; y++)
|
for (ValueIt y = beginY; y != endY; y++)
|
||||||
choice->push_back(NodePtr(new Leaf(*y)));
|
choice->push_back(NodePtr(new Leaf(*y)));
|
||||||
return Choice::Unique(choice);
|
return Choice::Unique(choice);
|
||||||
|
|
@ -589,7 +581,7 @@ namespace gtsam {
|
||||||
size_t split = size / nrChoices;
|
size_t split = size / nrChoices;
|
||||||
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
|
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
|
||||||
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
|
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
|
||||||
functions += DecisionTree(f);
|
functions.emplace_back(f);
|
||||||
}
|
}
|
||||||
return compose(functions.begin(), functions.end(), begin->first);
|
return compose(functions.begin(), functions.end(), begin->first);
|
||||||
}
|
}
|
||||||
|
|
@ -601,18 +593,16 @@ namespace gtsam {
|
||||||
const typename DecisionTree<M, X>::NodePtr& f,
|
const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<L(const M&)> L_of_M,
|
std::function<L(const M&)> L_of_M,
|
||||||
std::function<Y(const X&)> Y_of_X) const {
|
std::function<Y(const X&)> Y_of_X) const {
|
||||||
using MX = DecisionTree<M, X>;
|
|
||||||
using MXLeaf = typename MX::Leaf;
|
|
||||||
using MXChoice = typename MX::Choice;
|
|
||||||
using MXNodePtr = typename MX::NodePtr;
|
|
||||||
using LY = DecisionTree<L, Y>;
|
using LY = DecisionTree<L, Y>;
|
||||||
|
|
||||||
// ugliness below because apparently we can't have templated virtual functions
|
// ugliness below because apparently we can't have templated virtual functions
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf
|
// If leaf, apply unary conversion "op" and create a unique leaf
|
||||||
auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
|
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||||
if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
||||||
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
|
|
||||||
// Check if Choice
|
// Check if Choice
|
||||||
|
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||||
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
||||||
if (!choice) throw std::invalid_argument(
|
if (!choice) throw std::invalid_argument(
|
||||||
"DecisionTree::Convert: Invalid NodePtr");
|
"DecisionTree::Convert: Invalid NodePtr");
|
||||||
|
|
@ -623,14 +613,93 @@ namespace gtsam {
|
||||||
|
|
||||||
// put together via Shannon expansion otherwise not sorted.
|
// put together via Shannon expansion otherwise not sorted.
|
||||||
std::vector<LY> functions;
|
std::vector<LY> functions;
|
||||||
for(const MXNodePtr& branch: choice->branches()) {
|
for(auto && branch: choice->branches()) {
|
||||||
LY converted(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||||
functions += converted;
|
|
||||||
}
|
}
|
||||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
|
// Functor performing depth-first visit without Assignment<L> argument.
|
||||||
|
template <typename L, typename Y>
|
||||||
|
struct Visit {
|
||||||
|
using F = std::function<void(const Y&)>;
|
||||||
|
Visit(F f) : f(f) {} ///< Construct from folding function.
|
||||||
|
F f; ///< folding function object.
|
||||||
|
|
||||||
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
|
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
|
||||||
|
return f(leaf->constant());
|
||||||
|
|
||||||
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
|
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||||
|
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Func>
|
||||||
|
void DecisionTree<L, Y>::visit(Func f) const {
|
||||||
|
Visit<L, Y> visit(f);
|
||||||
|
visit(root_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************/
|
||||||
|
// Functor performing depth-first visit with Assignment<L> argument.
|
||||||
|
template <typename L, typename Y>
|
||||||
|
struct VisitWith {
|
||||||
|
using Choices = Assignment<L>;
|
||||||
|
using F = std::function<void(const Choices&, const Y&)>;
|
||||||
|
VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
||||||
|
Choices choices; ///< Assignment, mutating through recursion.
|
||||||
|
F f; ///< folding function object.
|
||||||
|
|
||||||
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
||||||
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
|
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
|
||||||
|
return f(choices, leaf->constant());
|
||||||
|
|
||||||
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
|
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||||
|
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||||
|
choices[choice->label()] = i; // Set assignment for label to i
|
||||||
|
(*this)(choice->branches()[i]); // recurse!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Func>
|
||||||
|
void DecisionTree<L, Y>::visitWith(Func f) const {
|
||||||
|
VisitWith<L, Y> visit(f);
|
||||||
|
visit(root_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************/
|
||||||
|
// fold is just done with a visit
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Func, typename X>
|
||||||
|
X DecisionTree<L, Y>::fold(Func f, X x0) const {
|
||||||
|
visit([&](const Y& y) { x0 = f(y, x0); });
|
||||||
|
return x0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************/
|
||||||
|
// labels is just done with a visit
|
||||||
|
template <typename L, typename Y>
|
||||||
|
std::set<L> DecisionTree<L, Y>::labels() const {
|
||||||
|
std::set<L> unique;
|
||||||
|
auto f = [&](const Assignment<L>& choices, const Y&) {
|
||||||
|
for (auto&& kv : choices) unique.insert(kv.first);
|
||||||
|
};
|
||||||
|
visitWith(f);
|
||||||
|
return unique;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||||
const CompareFunc& compare) const {
|
const CompareFunc& compare) const {
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -176,9 +177,8 @@ namespace gtsam {
|
||||||
* @param other The DecisionTree to convert from.
|
* @param other The DecisionTree to convert from.
|
||||||
* @param Y_of_X Functor to convert from value type X to type Y.
|
* @param Y_of_X Functor to convert from value type X to type Y.
|
||||||
*/
|
*/
|
||||||
template <typename X>
|
template <typename X, typename Func>
|
||||||
DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
|
||||||
std::function<Y(const X&)> Y_of_X);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a different value type X to value type Y, also transate
|
* @brief Convert from a different value type X to value type Y, also transate
|
||||||
|
|
@ -190,9 +190,9 @@ namespace gtsam {
|
||||||
* @param L_of_M Map from label type M to type L.
|
* @param L_of_M Map from label type M to type L.
|
||||||
* @param Y_of_X Functor to convert from type X to type Y.
|
* @param Y_of_X Functor to convert from type X to type Y.
|
||||||
*/
|
*/
|
||||||
template <typename M, typename X>
|
template <typename M, typename X, typename Func>
|
||||||
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
|
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
|
||||||
std::function<Y(const X&)> Y_of_X);
|
Func Y_of_X);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
|
@ -229,6 +229,52 @@ namespace gtsam {
|
||||||
/** evaluate */
|
/** evaluate */
|
||||||
const Y& operator()(const Assignment<L>& x) const;
|
const Y& operator()(const Assignment<L>& x) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Visit all leaves in depth-first fashion.
|
||||||
|
*
|
||||||
|
* @param f side-effect taking a value.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* int sum = 0;
|
||||||
|
* auto visitor = [&](int y) { sum += y; };
|
||||||
|
* tree.visitWith(visitor);
|
||||||
|
*/
|
||||||
|
template <typename Func>
|
||||||
|
void visit(Func f) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Visit all leaves in depth-first fashion.
|
||||||
|
*
|
||||||
|
* @param f side-effect taking an assignment and a value.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* int sum = 0;
|
||||||
|
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
|
||||||
|
* tree.visitWith(visitor);
|
||||||
|
*/
|
||||||
|
template <typename Func>
|
||||||
|
void visitWith(Func f) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Fold a binary function over the tree, returning accumulator.
|
||||||
|
*
|
||||||
|
* @tparam X type for accumulator.
|
||||||
|
* @param f binary function: Y * X -> X returning an updated accumulator.
|
||||||
|
* @param x0 initial value for accumulator.
|
||||||
|
* @return X final value for accumulator.
|
||||||
|
*
|
||||||
|
* @note X is always passed by value.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* auto add = [](const double& y, double x) { return y + x; };
|
||||||
|
* double sum = tree.fold(add, 0.0);
|
||||||
|
*/
|
||||||
|
template <typename Func, typename X>
|
||||||
|
X fold(Func f, X x0) const;
|
||||||
|
|
||||||
|
/** Retrieve all unique labels as a set. */
|
||||||
|
std::set<L> labels() const;
|
||||||
|
|
||||||
/** apply Unary operation "op" to f */
|
/** apply Unary operation "op" to f */
|
||||||
DecisionTree apply(const Unary& op) const;
|
DecisionTree apply(const Unary& op) const;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -123,8 +123,7 @@ struct Ring {
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
TEST(DT, example)
|
TEST(DecisionTree, example) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
|
|
@ -231,13 +230,10 @@ TEST(DT, example)
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// test Conversion of values
|
// test Conversion of values
|
||||||
std::function<bool(const int&)> bool_of_int = [](const int& y) {
|
bool bool_of_int(const int& y) { return y != 0; };
|
||||||
return y != 0;
|
|
||||||
};
|
|
||||||
typedef DecisionTree<string, bool> StringBoolTree;
|
typedef DecisionTree<string, bool> StringBoolTree;
|
||||||
|
|
||||||
TEST(DT, ConvertValuesOnly)
|
TEST(DecisionTree, ConvertValuesOnly) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
|
@ -260,8 +256,7 @@ enum Label {
|
||||||
};
|
};
|
||||||
typedef DecisionTree<Label, bool> LabelBoolTree;
|
typedef DecisionTree<Label, bool> LabelBoolTree;
|
||||||
|
|
||||||
TEST(DT, ConvertBoth)
|
TEST(DecisionTree, ConvertBoth) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
|
@ -272,7 +267,7 @@ TEST(DT, ConvertBoth)
|
||||||
map<string, Label> ordering;
|
map<string, Label> ordering;
|
||||||
ordering[A] = X;
|
ordering[A] = X;
|
||||||
ordering[B] = Y;
|
ordering[B] = Y;
|
||||||
LabelBoolTree f2(f1, ordering, bool_of_int);
|
LabelBoolTree f2(f1, ordering, &bool_of_int);
|
||||||
|
|
||||||
// Check some values
|
// Check some values
|
||||||
Assignment<Label> x00, x01, x10, x11;
|
Assignment<Label> x00, x01, x10, x11;
|
||||||
|
|
@ -288,8 +283,7 @@ TEST(DT, ConvertBoth)
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// test Compose expansion
|
// test Compose expansion
|
||||||
TEST(DT, Compose)
|
TEST(DecisionTree, Compose) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
|
|
@ -314,6 +308,73 @@ TEST(DT, Compose)
|
||||||
DOT(f5);
|
DOT(f5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
// Check we can create a decision tree of containers.
|
||||||
|
TEST(DecisionTree, Containers) {
|
||||||
|
using Container = std::vector<double>;
|
||||||
|
using StringContainerTree = DecisionTree<string, Container>;
|
||||||
|
|
||||||
|
// Check default constructor
|
||||||
|
StringContainerTree tree;
|
||||||
|
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
|
||||||
|
// Check conversion
|
||||||
|
auto container_of_int = [](const int& i) {
|
||||||
|
Container c;
|
||||||
|
c.emplace_back(i);
|
||||||
|
return c;
|
||||||
|
};
|
||||||
|
StringContainerTree converted(stringIntTree, container_of_int);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
// Test visit.
|
||||||
|
TEST(DecisionTree, visit) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
double sum = 0.0;
|
||||||
|
auto visitor = [&](int y) { sum += y; };
|
||||||
|
tree.visit(visitor);
|
||||||
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
// Test visit, with Choices argument.
|
||||||
|
TEST(DecisionTree, visitWith) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
double sum = 0.0;
|
||||||
|
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
||||||
|
tree.visitWith(visitor);
|
||||||
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
// Test fold.
|
||||||
|
TEST(DecisionTree, fold) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
auto add = [](const int& y, double x) { return y + x; };
|
||||||
|
double sum = tree.fold(add, 0.0);
|
||||||
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
// Test retrieving all labels.
|
||||||
|
TEST(DecisionTree, labels) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
auto labels = tree.labels();
|
||||||
|
EXPECT_LONGS_EQUAL(2, labels.size());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -142,10 +142,10 @@ public:
|
||||||
return q - (*this);
|
return q - (*this);
|
||||||
}
|
}
|
||||||
Vector6 GTSAM_DEPRECATED localCoordinates(const ConstantBias& q) {
|
Vector6 GTSAM_DEPRECATED localCoordinates(const ConstantBias& q) {
|
||||||
return between(q).vector();
|
return (q - (*this)).vector();
|
||||||
}
|
}
|
||||||
ConstantBias GTSAM_DEPRECATED retract(const Vector6& v) {
|
ConstantBias GTSAM_DEPRECATED retract(const Vector6& v) {
|
||||||
return compose(ConstantBias(v));
|
return (*this) + ConstantBias(v);
|
||||||
}
|
}
|
||||||
static Vector6 GTSAM_DEPRECATED Logmap(const ConstantBias& p) {
|
static Vector6 GTSAM_DEPRECATED Logmap(const ConstantBias& p) {
|
||||||
return p.vector();
|
return p.vector();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue