DT::split

release/4.3a0
Frank Dellaert 2024-10-15 14:34:42 +09:00
parent b56595c6f8
commit 6c9b25c45e
3 changed files with 71 additions and 9 deletions

View File

@ -29,6 +29,7 @@
#include <optional>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
@ -482,8 +483,8 @@ namespace gtsam {
/****************************************************************************/
// DecisionTree
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() {}
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() : root_(nullptr) {}
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
@ -951,11 +952,16 @@ namespace gtsam {
return root_->equals(*other.root_);
}
/****************************************************************************/
template<typename L, typename Y>
const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
if (root_ == nullptr)
throw std::invalid_argument(
"DecisionTree::operator() called on empty tree");
return root_->operator ()(x);
}
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
// It is unclear what should happen if tree is empty:
@ -966,6 +972,7 @@ namespace gtsam {
return DecisionTree(root_->apply(op));
}
/****************************************************************************/
/// Apply unary operator with assignment
template <typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
@ -1049,6 +1056,18 @@ namespace gtsam {
return ss.str();
}
/******************************************************************************/
/******************************************************************************/
template <typename L, typename Y>
template <typename A, typename B>
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> DecisionTree<L, Y>::split(
std::function<std::pair<A, B>(const Y&)> AB_of_Y) {
using AB = std::pair<A, B>;
const DecisionTree<L, AB> ab(*this, AB_of_Y);
const DecisionTree<L, A> a(ab, [](const AB& p) { return p.first; });
const DecisionTree<L, B> b(ab, [](const AB& p) { return p.second; });
return {a, b};
}
/******************************************************************************/
} // namespace gtsam

View File

@ -156,10 +156,10 @@ namespace gtsam {
template <typename It, typename ValueIt>
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);
/** Internal helper function to create from
* keys, cardinalities, and Y values.
* Calls `build` which builds thetree bottom-up,
* before we prune in a top-down fashion.
/**
* Internal helper function to create a tree from keys, cardinalities, and Y
* values. Calls `build` which builds the tree bottom-up, before we prune in
* a top-down fashion.
*/
template <typename It, typename ValueIt>
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
@ -239,7 +239,7 @@ namespace gtsam {
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
/**
* @brief Convert from a different value type X to value type Y, also transate
* @brief Convert from a different value type X to value type Y, also translate
* labels via map from type M to L.
*
* @tparam M Previous label type.
@ -406,6 +406,18 @@ namespace gtsam {
const ValueFormatter& valueFormatter,
bool showZero = true) const;
/**
* @brief Convert into two trees with value types A and B.
*
* @tparam A First new value type.
* @tparam B Second new value type.
* @param AB_of_Y Functor to convert from type X to std::pair<A, B>.
* @return A pair of DecisionTrees with value types A and B respectively.
*/
template <typename A, typename B>
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> split(
std::function<std::pair<A, B>(const Y&)> AB_of_Y);
/// @name Advanced Interface
/// @{

View File

@ -11,7 +11,7 @@
/*
* @file testDecisionTree.cpp
* @brief Develop DecisionTree
* @brief DecisionTree unit tests
* @author Frank Dellaert
* @author Can Erdogan
* @date Jan 30, 2012
@ -271,6 +271,37 @@ TEST(DecisionTree, Example) {
DOT(acnotb);
}
/* ************************************************************************** */
// Test that we can create two trees out of one, using a function that returns a pair.
TEST(DecisionTree, Split) {
// Create labels
string A("A"), B("B");
// Create a decision tree
DT original(A, DT(B, 1, 2), DT(B, 3, 4));
// Define a function that returns an int/bool pair
auto split_function = [](const int& value) -> std::pair<int, bool> {
return {value*3, value*3 % 2 == 0};
};
// Split the original tree into two new trees
auto [la,lb] = original.split<int,bool>(split_function);
// Check the first resulting tree
EXPECT_LONGS_EQUAL(3, la(Assignment<string>{{A, 0}, {B, 0}}));
EXPECT_LONGS_EQUAL(6, la(Assignment<string>{{A, 0}, {B, 1}}));
EXPECT_LONGS_EQUAL(9, la(Assignment<string>{{A, 1}, {B, 0}}));
EXPECT_LONGS_EQUAL(12, la(Assignment<string>{{A, 1}, {B, 1}}));
// Check the second resulting tree
EXPECT(!lb(Assignment<string>{{A, 0}, {B, 0}}));
EXPECT(lb(Assignment<string>{{A, 0}, {B, 1}}));
EXPECT(!lb(Assignment<string>{{A, 1}, {B, 0}}));
EXPECT(lb(Assignment<string>{{A, 1}, {B, 1}}));
}
/* ************************************************************************** */
// test Conversion of values
bool bool_of_int(const int& y) { return y != 0; };