DT::split
							parent
							
								
									b56595c6f8
								
							
						
					
					
						commit
						6c9b25c45e
					
				|  | @ -29,6 +29,7 @@ | |||
| #include <optional> | ||||
| #include <set> | ||||
| #include <sstream> | ||||
| #include <stdexcept> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
|  | @ -482,8 +483,8 @@ namespace gtsam { | |||
|   /****************************************************************************/ | ||||
|   // DecisionTree
 | ||||
|   /****************************************************************************/ | ||||
|   template<typename L, typename Y> | ||||
|   DecisionTree<L, Y>::DecisionTree() {} | ||||
|   template <typename L, typename Y> | ||||
|   DecisionTree<L, Y>::DecisionTree() : root_(nullptr) {} | ||||
| 
 | ||||
|   template<typename L, typename Y> | ||||
|   DecisionTree<L, Y>::DecisionTree(const NodePtr& root) : | ||||
|  | @ -951,11 +952,16 @@ namespace gtsam { | |||
|     return root_->equals(*other.root_); | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|   template<typename L, typename Y> | ||||
|   const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const { | ||||
|     if (root_ == nullptr) | ||||
|       throw std::invalid_argument( | ||||
|           "DecisionTree::operator() called on empty tree"); | ||||
|     return root_->operator ()(x); | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|   template<typename L, typename Y> | ||||
|   DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const { | ||||
|     // It is unclear what should happen if tree is empty:
 | ||||
|  | @ -966,6 +972,7 @@ namespace gtsam { | |||
|     return DecisionTree(root_->apply(op)); | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|   /// Apply unary operator with assignment
 | ||||
|   template <typename L, typename Y> | ||||
|   DecisionTree<L, Y> DecisionTree<L, Y>::apply( | ||||
|  | @ -1049,6 +1056,18 @@ namespace gtsam { | |||
|     return ss.str(); | ||||
|   } | ||||
| 
 | ||||
| /******************************************************************************/ | ||||
|   /******************************************************************************/ | ||||
|   template <typename L, typename Y> | ||||
|   template <typename A, typename B> | ||||
|   std::pair<DecisionTree<L, A>, DecisionTree<L, B>> DecisionTree<L, Y>::split( | ||||
|       std::function<std::pair<A, B>(const Y&)> AB_of_Y) { | ||||
|     using AB = std::pair<A, B>; | ||||
|     const DecisionTree<L, AB> ab(*this, AB_of_Y); | ||||
|     const DecisionTree<L, A> a(ab, [](const AB& p) { return p.first; }); | ||||
|     const DecisionTree<L, B> b(ab, [](const AB& p) { return p.second; }); | ||||
|     return {a, b}; | ||||
|   } | ||||
| 
 | ||||
|   /******************************************************************************/ | ||||
| 
 | ||||
|   }  // namespace gtsam
 | ||||
|  |  | |||
|  | @ -156,10 +156,10 @@ namespace gtsam { | |||
|     template <typename It, typename ValueIt> | ||||
|     static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY); | ||||
| 
 | ||||
|     /** Internal helper function to create from
 | ||||
|      * keys, cardinalities, and Y values. | ||||
|      * Calls `build` which builds thetree bottom-up, | ||||
|      * before we prune in a top-down fashion. | ||||
|     /**
 | ||||
|      * Internal helper function to create a tree from keys, cardinalities, and Y | ||||
|      * values. Calls `build` which builds the tree bottom-up, before we prune in | ||||
|      * a top-down fashion. | ||||
|      */ | ||||
|     template <typename It, typename ValueIt> | ||||
|     static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY); | ||||
|  | @ -239,7 +239,7 @@ namespace gtsam { | |||
|     DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X); | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Convert from a different value type X to value type Y, also transate | ||||
|      * @brief Convert from a different value type X to value type Y, also translate | ||||
|      * labels via map from type M to L. | ||||
|      * | ||||
|      * @tparam M Previous label type. | ||||
|  | @ -406,6 +406,18 @@ namespace gtsam { | |||
|                     const ValueFormatter& valueFormatter, | ||||
|                     bool showZero = true) const; | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Convert into two trees with value types A and B. | ||||
|      * | ||||
|      * @tparam A First new value type. | ||||
|      * @tparam B Second new value type. | ||||
|      * @param AB_of_Y Functor to convert from type X to std::pair<A, B>. | ||||
|      * @return A pair of DecisionTrees with value types A and B respectively. | ||||
|      */ | ||||
|     template <typename A, typename B> | ||||
|     std::pair<DecisionTree<L, A>, DecisionTree<L, B>> split( | ||||
|         std::function<std::pair<A, B>(const Y&)> AB_of_Y); | ||||
| 
 | ||||
|     /// @name Advanced Interface
 | ||||
|     /// @{
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -11,7 +11,7 @@ | |||
| 
 | ||||
| /*
 | ||||
|  * @file    testDecisionTree.cpp | ||||
|  * @brief    Develop DecisionTree | ||||
|  * @brief   DecisionTree unit tests | ||||
|  * @author  Frank Dellaert | ||||
|  * @author  Can Erdogan | ||||
|  * @date    Jan 30, 2012 | ||||
|  | @ -271,6 +271,37 @@ TEST(DecisionTree, Example) { | |||
|   DOT(acnotb); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test that we can create two trees out of one, using a function that returns a pair.
 | ||||
| TEST(DecisionTree, Split) { | ||||
|   // Create labels
 | ||||
|   string A("A"), B("B"); | ||||
| 
 | ||||
|   // Create a decision tree
 | ||||
|   DT original(A, DT(B, 1, 2), DT(B, 3, 4)); | ||||
| 
 | ||||
|   // Define a function that returns an int/bool pair
 | ||||
|   auto split_function = [](const int& value) -> std::pair<int, bool> { | ||||
|     return {value*3, value*3 % 2 == 0}; | ||||
|   }; | ||||
| 
 | ||||
|   // Split the original tree into two new trees
 | ||||
|   auto [la,lb] = original.split<int,bool>(split_function); | ||||
| 
 | ||||
|   // Check the first resulting tree
 | ||||
|   EXPECT_LONGS_EQUAL(3, la(Assignment<string>{{A, 0}, {B, 0}})); | ||||
|   EXPECT_LONGS_EQUAL(6, la(Assignment<string>{{A, 0}, {B, 1}})); | ||||
|   EXPECT_LONGS_EQUAL(9, la(Assignment<string>{{A, 1}, {B, 0}})); | ||||
|   EXPECT_LONGS_EQUAL(12, la(Assignment<string>{{A, 1}, {B, 1}})); | ||||
| 
 | ||||
|   // Check the second resulting tree
 | ||||
|   EXPECT(!lb(Assignment<string>{{A, 0}, {B, 0}})); | ||||
|   EXPECT(lb(Assignment<string>{{A, 0}, {B, 1}})); | ||||
|   EXPECT(!lb(Assignment<string>{{A, 1}, {B, 0}})); | ||||
|   EXPECT(lb(Assignment<string>{{A, 1}, {B, 1}})); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // test Conversion of values
 | ||||
| bool bool_of_int(const int& y) { return y != 0; }; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue