DT::split
parent
b56595c6f8
commit
6c9b25c45e
|
@ -29,6 +29,7 @@
|
|||
#include <optional>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
@ -482,8 +483,8 @@ namespace gtsam {
|
|||
/****************************************************************************/
|
||||
// DecisionTree
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree() {}
|
||||
template <typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree() : root_(nullptr) {}
|
||||
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
|
||||
|
@ -951,11 +952,16 @@ namespace gtsam {
|
|||
return root_->equals(*other.root_);
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
|
||||
if (root_ == nullptr)
|
||||
throw std::invalid_argument(
|
||||
"DecisionTree::operator() called on empty tree");
|
||||
return root_->operator ()(x);
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
||||
// It is unclear what should happen if tree is empty:
|
||||
|
@ -966,6 +972,7 @@ namespace gtsam {
|
|||
return DecisionTree(root_->apply(op));
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
/// Apply unary operator with assignment
|
||||
template <typename L, typename Y>
|
||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
|
||||
|
@ -1049,6 +1056,18 @@ namespace gtsam {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
/******************************************************************************/
|
||||
/******************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
template <typename A, typename B>
|
||||
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> DecisionTree<L, Y>::split(
|
||||
std::function<std::pair<A, B>(const Y&)> AB_of_Y) {
|
||||
using AB = std::pair<A, B>;
|
||||
const DecisionTree<L, AB> ab(*this, AB_of_Y);
|
||||
const DecisionTree<L, A> a(ab, [](const AB& p) { return p.first; });
|
||||
const DecisionTree<L, B> b(ab, [](const AB& p) { return p.second; });
|
||||
return {a, b};
|
||||
}
|
||||
|
||||
/******************************************************************************/
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -156,10 +156,10 @@ namespace gtsam {
|
|||
template <typename It, typename ValueIt>
|
||||
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);
|
||||
|
||||
/** Internal helper function to create from
|
||||
* keys, cardinalities, and Y values.
|
||||
* Calls `build` which builds thetree bottom-up,
|
||||
* before we prune in a top-down fashion.
|
||||
/**
|
||||
* Internal helper function to create a tree from keys, cardinalities, and Y
|
||||
* values. Calls `build` which builds the tree bottom-up, before we prune in
|
||||
* a top-down fashion.
|
||||
*/
|
||||
template <typename It, typename ValueIt>
|
||||
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
|
||||
|
@ -239,7 +239,7 @@ namespace gtsam {
|
|||
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
|
||||
|
||||
/**
|
||||
* @brief Convert from a different value type X to value type Y, also transate
|
||||
* @brief Convert from a different value type X to value type Y, also translate
|
||||
* labels via map from type M to L.
|
||||
*
|
||||
* @tparam M Previous label type.
|
||||
|
@ -406,6 +406,18 @@ namespace gtsam {
|
|||
const ValueFormatter& valueFormatter,
|
||||
bool showZero = true) const;
|
||||
|
||||
/**
|
||||
* @brief Convert into two trees with value types A and B.
|
||||
*
|
||||
* @tparam A First new value type.
|
||||
* @tparam B Second new value type.
|
||||
* @param AB_of_Y Functor to convert from type X to std::pair<A, B>.
|
||||
* @return A pair of DecisionTrees with value types A and B respectively.
|
||||
*/
|
||||
template <typename A, typename B>
|
||||
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> split(
|
||||
std::function<std::pair<A, B>(const Y&)> AB_of_Y);
|
||||
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
/*
|
||||
* @file testDecisionTree.cpp
|
||||
* @brief Develop DecisionTree
|
||||
* @brief DecisionTree unit tests
|
||||
* @author Frank Dellaert
|
||||
* @author Can Erdogan
|
||||
* @date Jan 30, 2012
|
||||
|
@ -271,6 +271,37 @@ TEST(DecisionTree, Example) {
|
|||
DOT(acnotb);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test that we can create two trees out of one, using a function that returns a pair.
|
||||
TEST(DecisionTree, Split) {
|
||||
// Create labels
|
||||
string A("A"), B("B");
|
||||
|
||||
// Create a decision tree
|
||||
DT original(A, DT(B, 1, 2), DT(B, 3, 4));
|
||||
|
||||
// Define a function that returns an int/bool pair
|
||||
auto split_function = [](const int& value) -> std::pair<int, bool> {
|
||||
return {value*3, value*3 % 2 == 0};
|
||||
};
|
||||
|
||||
// Split the original tree into two new trees
|
||||
auto [la,lb] = original.split<int,bool>(split_function);
|
||||
|
||||
// Check the first resulting tree
|
||||
EXPECT_LONGS_EQUAL(3, la(Assignment<string>{{A, 0}, {B, 0}}));
|
||||
EXPECT_LONGS_EQUAL(6, la(Assignment<string>{{A, 0}, {B, 1}}));
|
||||
EXPECT_LONGS_EQUAL(9, la(Assignment<string>{{A, 1}, {B, 0}}));
|
||||
EXPECT_LONGS_EQUAL(12, la(Assignment<string>{{A, 1}, {B, 1}}));
|
||||
|
||||
// Check the second resulting tree
|
||||
EXPECT(!lb(Assignment<string>{{A, 0}, {B, 0}}));
|
||||
EXPECT(lb(Assignment<string>{{A, 0}, {B, 1}}));
|
||||
EXPECT(!lb(Assignment<string>{{A, 1}, {B, 0}}));
|
||||
EXPECT(lb(Assignment<string>{{A, 1}, {B, 1}}));
|
||||
}
|
||||
|
||||
|
||||
/* ************************************************************************** */
|
||||
// test Conversion of values
|
||||
bool bool_of_int(const int& y) { return y != 0; };
|
||||
|
|
Loading…
Reference in New Issue