Fixed thresholding and fold example
							parent
							
								
									fa1cde2f60
								
							
						
					
					
						commit
						8db7f25021
					
				| 
						 | 
				
			
			@ -24,8 +24,8 @@ using namespace boost::assign;
 | 
			
		|||
#include <gtsam/base/Testable.h>
 | 
			
		||||
#include <gtsam/discrete/Signature.h>
 | 
			
		||||
 | 
			
		||||
//#define DT_DEBUG_MEMORY
 | 
			
		||||
//#define DT_NO_PRUNING
 | 
			
		||||
// #define DT_DEBUG_MEMORY
 | 
			
		||||
// #define DT_NO_PRUNING
 | 
			
		||||
#define DISABLE_DOT
 | 
			
		||||
#include <gtsam/discrete/DecisionTree-inl.h>
 | 
			
		||||
using namespace std;
 | 
			
		||||
| 
						 | 
				
			
			@ -349,10 +349,10 @@ TEST(DecisionTree, visitWith) {
 | 
			
		|||
TEST(DecisionTree, fold) {
 | 
			
		||||
  // Create small two-level tree
 | 
			
		||||
  string A("A"), B("B");
 | 
			
		||||
  DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
 | 
			
		||||
  DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
 | 
			
		||||
  auto add = [](const int& y, double x) { return y + x; };
 | 
			
		||||
  double sum = tree.fold(add, 0.0);
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);  // Note, not 7, due to pruning!
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************** */
 | 
			
		||||
| 
						 | 
				
			
			@ -365,7 +365,7 @@ TEST(DecisionTree, labels) {
 | 
			
		|||
  EXPECT_LONGS_EQUAL(2, labels.size());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ******************************************************************************** */
 | 
			
		||||
/* ************************************************************************** */
 | 
			
		||||
// Test unzip method.
 | 
			
		||||
TEST(DecisionTree, unzip) {
 | 
			
		||||
  using DTP = DecisionTree<string, std::pair<int, string>>;
 | 
			
		||||
| 
						 | 
				
			
			@ -374,10 +374,8 @@ TEST(DecisionTree, unzip) {
 | 
			
		|||
 | 
			
		||||
  // Create small two-level tree
 | 
			
		||||
  string A("A"), B("B"), C("C");
 | 
			
		||||
  DTP tree(B,
 | 
			
		||||
           DTP(A, {0, "zero"}, {1, "one"}),
 | 
			
		||||
           DTP(A, {2, "two"}, {1337, "l33t"})
 | 
			
		||||
  );
 | 
			
		||||
  DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
 | 
			
		||||
           DTP(A, {2, "two"}, {1337, "l33t"}));
 | 
			
		||||
 | 
			
		||||
  DT1 dt1;
 | 
			
		||||
  DT2 dt2;
 | 
			
		||||
| 
						 | 
				
			
			@ -398,7 +396,7 @@ TEST(DecisionTree, threshold) {
 | 
			
		|||
  keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
 | 
			
		||||
  DT tree(keys, "0 1 2 3 4 5 6 7");
 | 
			
		||||
 | 
			
		||||
  // Check number of elements equal to zero
 | 
			
		||||
  // Check number of leaves equal to zero
 | 
			
		||||
  auto count = [](const int& value, int count) {
 | 
			
		||||
    return value == 0 ? count + 1 : count;
 | 
			
		||||
  };
 | 
			
		||||
| 
						 | 
				
			
			@ -408,9 +406,9 @@ TEST(DecisionTree, threshold) {
 | 
			
		|||
  auto threshold = [](int value) { return value < 5 ? 0 : value; };
 | 
			
		||||
  DT thresholded(tree, threshold);
 | 
			
		||||
 | 
			
		||||
  // Check number of elements equal to zero now = 5
 | 
			
		||||
  // TODO(frank): it is 2, because the pruned branches are counted as 1!
 | 
			
		||||
  EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0));
 | 
			
		||||
  // Check number of leaves equal to zero now = 2
 | 
			
		||||
  // Note: it is 2, because the pruned branches are counted as 1!
 | 
			
		||||
  EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue