diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 4266ace15..434a6beac 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -482,8 +483,8 @@ namespace gtsam { /****************************************************************************/ // DecisionTree /****************************************************************************/ - template - DecisionTree::DecisionTree() {} + template + DecisionTree::DecisionTree() : root_(nullptr) {} template DecisionTree::DecisionTree(const NodePtr& root) : @@ -951,11 +952,16 @@ namespace gtsam { return root_->equals(*other.root_); } + /****************************************************************************/ template const Y& DecisionTree::operator()(const Assignment& x) const { + if (root_ == nullptr) + throw std::invalid_argument( + "DecisionTree::operator() called on empty tree"); return root_->operator ()(x); } + /****************************************************************************/ template DecisionTree DecisionTree::apply(const Unary& op) const { // It is unclear what should happen if tree is empty: @@ -966,6 +972,7 @@ namespace gtsam { return DecisionTree(root_->apply(op)); } + /****************************************************************************/ /// Apply unary operator with assignment template DecisionTree DecisionTree::apply( @@ -1049,6 +1056,18 @@ namespace gtsam { return ss.str(); } -/******************************************************************************/ + /******************************************************************************/ + template + template + std::pair, DecisionTree> DecisionTree::split( + std::function(const Y&)> AB_of_Y) { + using AB = std::pair; + const DecisionTree ab(*this, AB_of_Y); + const DecisionTree a(ab, [](const AB& p) { return p.first; }); + const DecisionTree b(ab, [](const AB& p) { return p.second; }); + return {a, b}; + } + + /******************************************************************************/ } // namespace gtsam diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 6d8d86530..6a966e0dd 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -156,10 +156,10 @@ namespace gtsam { template static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY); - /** Internal helper function to create from - * keys, cardinalities, and Y values. - * Calls `build` which builds thetree bottom-up, - * before we prune in a top-down fashion. + /** + * Internal helper function to create a tree from keys, cardinalities, and Y + * values. Calls `build` which builds the tree bottom-up, before we prune in + * a top-down fashion. */ template static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY); @@ -239,7 +239,7 @@ namespace gtsam { DecisionTree(const DecisionTree& other, Func Y_of_X); /** - * @brief Convert from a different value type X to value type Y, also transate + * @brief Convert from a different value type X to value type Y, also translate * labels via map from type M to L. * * @tparam M Previous label type. @@ -406,6 +406,18 @@ namespace gtsam { const ValueFormatter& valueFormatter, bool showZero = true) const; + /** + * @brief Convert into two trees with value types A and B. + * + * @tparam A First new value type. + * @tparam B Second new value type. + * @param AB_of_Y Functor to convert from type X to std::pair. + * @return A pair of DecisionTrees with value types A and B respectively. + */ + template + std::pair, DecisionTree> split( + std::function(const Y&)> AB_of_Y); + /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index c625e1ba6..1382fc704 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -11,7 +11,7 @@ /* * @file testDecisionTree.cpp - * @brief Develop DecisionTree + * @brief DecisionTree unit tests * @author Frank Dellaert * @author Can Erdogan * @date Jan 30, 2012 @@ -271,6 +271,37 @@ TEST(DecisionTree, Example) { DOT(acnotb); } +/* ************************************************************************** */ +// Test that we can create two trees out of one, using a function that returns a pair. +TEST(DecisionTree, Split) { + // Create labels + string A("A"), B("B"); + + // Create a decision tree + DT original(A, DT(B, 1, 2), DT(B, 3, 4)); + + // Define a function that returns an int/bool pair + auto split_function = [](const int& value) -> std::pair { + return {value*3, value*3 % 2 == 0}; + }; + + // Split the original tree into two new trees + auto [la,lb] = original.split(split_function); + + // Check the first resulting tree + EXPECT_LONGS_EQUAL(3, la(Assignment{{A, 0}, {B, 0}})); + EXPECT_LONGS_EQUAL(6, la(Assignment{{A, 0}, {B, 1}})); + EXPECT_LONGS_EQUAL(9, la(Assignment{{A, 1}, {B, 0}})); + EXPECT_LONGS_EQUAL(12, la(Assignment{{A, 1}, {B, 1}})); + + // Check the second resulting tree + EXPECT(!lb(Assignment{{A, 0}, {B, 0}})); + EXPECT(lb(Assignment{{A, 0}, {B, 1}})); + EXPECT(!lb(Assignment{{A, 1}, {B, 0}})); + EXPECT(lb(Assignment{{A, 1}, {B, 1}})); +} + + /* ************************************************************************** */ // test Conversion of values bool bool_of_int(const int& y) { return y != 0; };