DT::split
parent
b56595c6f8
commit
6c9b25c45e
|
@ -29,6 +29,7 @@
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -482,8 +483,8 @@ namespace gtsam {
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
// DecisionTree
|
// DecisionTree
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree() {}
|
DecisionTree<L, Y>::DecisionTree() : root_(nullptr) {}
|
||||||
|
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
|
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
|
||||||
|
@ -951,11 +952,16 @@ namespace gtsam {
|
||||||
return root_->equals(*other.root_);
|
return root_->equals(*other.root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
|
const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
|
||||||
|
if (root_ == nullptr)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DecisionTree::operator() called on empty tree");
|
||||||
return root_->operator ()(x);
|
return root_->operator ()(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
||||||
// It is unclear what should happen if tree is empty:
|
// It is unclear what should happen if tree is empty:
|
||||||
|
@ -966,6 +972,7 @@ namespace gtsam {
|
||||||
return DecisionTree(root_->apply(op));
|
return DecisionTree(root_->apply(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
/// Apply unary operator with assignment
|
/// Apply unary operator with assignment
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
|
||||||
|
@ -1049,6 +1056,18 @@ namespace gtsam {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
/******************************************************************************/
|
/******************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename A, typename B>
|
||||||
|
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> DecisionTree<L, Y>::split(
|
||||||
|
std::function<std::pair<A, B>(const Y&)> AB_of_Y) {
|
||||||
|
using AB = std::pair<A, B>;
|
||||||
|
const DecisionTree<L, AB> ab(*this, AB_of_Y);
|
||||||
|
const DecisionTree<L, A> a(ab, [](const AB& p) { return p.first; });
|
||||||
|
const DecisionTree<L, B> b(ab, [](const AB& p) { return p.second; });
|
||||||
|
return {a, b};
|
||||||
|
}
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -156,10 +156,10 @@ namespace gtsam {
|
||||||
template <typename It, typename ValueIt>
|
template <typename It, typename ValueIt>
|
||||||
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);
|
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);
|
||||||
|
|
||||||
/** Internal helper function to create from
|
/**
|
||||||
* keys, cardinalities, and Y values.
|
* Internal helper function to create a tree from keys, cardinalities, and Y
|
||||||
* Calls `build` which builds thetree bottom-up,
|
* values. Calls `build` which builds the tree bottom-up, before we prune in
|
||||||
* before we prune in a top-down fashion.
|
* a top-down fashion.
|
||||||
*/
|
*/
|
||||||
template <typename It, typename ValueIt>
|
template <typename It, typename ValueIt>
|
||||||
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
|
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
|
||||||
|
@ -239,7 +239,7 @@ namespace gtsam {
|
||||||
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
|
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a different value type X to value type Y, also transate
|
* @brief Convert from a different value type X to value type Y, also translate
|
||||||
* labels via map from type M to L.
|
* labels via map from type M to L.
|
||||||
*
|
*
|
||||||
* @tparam M Previous label type.
|
* @tparam M Previous label type.
|
||||||
|
@ -406,6 +406,18 @@ namespace gtsam {
|
||||||
const ValueFormatter& valueFormatter,
|
const ValueFormatter& valueFormatter,
|
||||||
bool showZero = true) const;
|
bool showZero = true) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Convert into two trees with value types A and B.
|
||||||
|
*
|
||||||
|
* @tparam A First new value type.
|
||||||
|
* @tparam B Second new value type.
|
||||||
|
* @param AB_of_Y Functor to convert from type X to std::pair<A, B>.
|
||||||
|
* @return A pair of DecisionTrees with value types A and B respectively.
|
||||||
|
*/
|
||||||
|
template <typename A, typename B>
|
||||||
|
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> split(
|
||||||
|
std::function<std::pair<A, B>(const Y&)> AB_of_Y);
|
||||||
|
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* @file testDecisionTree.cpp
|
* @file testDecisionTree.cpp
|
||||||
* @brief Develop DecisionTree
|
* @brief DecisionTree unit tests
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
* @author Can Erdogan
|
* @author Can Erdogan
|
||||||
* @date Jan 30, 2012
|
* @date Jan 30, 2012
|
||||||
|
@ -271,6 +271,37 @@ TEST(DecisionTree, Example) {
|
||||||
DOT(acnotb);
|
DOT(acnotb);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test that we can create two trees out of one, using a function that returns a pair.
|
||||||
|
TEST(DecisionTree, Split) {
|
||||||
|
// Create labels
|
||||||
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
// Create a decision tree
|
||||||
|
DT original(A, DT(B, 1, 2), DT(B, 3, 4));
|
||||||
|
|
||||||
|
// Define a function that returns an int/bool pair
|
||||||
|
auto split_function = [](const int& value) -> std::pair<int, bool> {
|
||||||
|
return {value*3, value*3 % 2 == 0};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Split the original tree into two new trees
|
||||||
|
auto [la,lb] = original.split<int,bool>(split_function);
|
||||||
|
|
||||||
|
// Check the first resulting tree
|
||||||
|
EXPECT_LONGS_EQUAL(3, la(Assignment<string>{{A, 0}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(6, la(Assignment<string>{{A, 0}, {B, 1}}));
|
||||||
|
EXPECT_LONGS_EQUAL(9, la(Assignment<string>{{A, 1}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(12, la(Assignment<string>{{A, 1}, {B, 1}}));
|
||||||
|
|
||||||
|
// Check the second resulting tree
|
||||||
|
EXPECT(!lb(Assignment<string>{{A, 0}, {B, 0}}));
|
||||||
|
EXPECT(lb(Assignment<string>{{A, 0}, {B, 1}}));
|
||||||
|
EXPECT(!lb(Assignment<string>{{A, 1}, {B, 0}}));
|
||||||
|
EXPECT(lb(Assignment<string>{{A, 1}, {B, 1}}));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Conversion of values
|
// test Conversion of values
|
||||||
bool bool_of_int(const int& y) { return y != 0; };
|
bool bool_of_int(const int& y) { return y != 0; };
|
||||||
|
|
Loading…
Reference in New Issue