Fixed thresholding and fold example
							parent
							
								
									fa1cde2f60
								
							
						
					
					
						commit
						8db7f25021
					
				| 
						 | 
					@ -24,8 +24,8 @@ using namespace boost::assign;
 | 
				
			||||||
#include <gtsam/base/Testable.h>
 | 
					#include <gtsam/base/Testable.h>
 | 
				
			||||||
#include <gtsam/discrete/Signature.h>
 | 
					#include <gtsam/discrete/Signature.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//#define DT_DEBUG_MEMORY
 | 
					// #define DT_DEBUG_MEMORY
 | 
				
			||||||
//#define DT_NO_PRUNING
 | 
					// #define DT_NO_PRUNING
 | 
				
			||||||
#define DISABLE_DOT
 | 
					#define DISABLE_DOT
 | 
				
			||||||
#include <gtsam/discrete/DecisionTree-inl.h>
 | 
					#include <gtsam/discrete/DecisionTree-inl.h>
 | 
				
			||||||
using namespace std;
 | 
					using namespace std;
 | 
				
			||||||
| 
						 | 
					@ -349,10 +349,10 @@ TEST(DecisionTree, visitWith) {
 | 
				
			||||||
TEST(DecisionTree, fold) {
 | 
					TEST(DecisionTree, fold) {
 | 
				
			||||||
  // Create small two-level tree
 | 
					  // Create small two-level tree
 | 
				
			||||||
  string A("A"), B("B");
 | 
					  string A("A"), B("B");
 | 
				
			||||||
  DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
 | 
					  DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
 | 
				
			||||||
  auto add = [](const int& y, double x) { return y + x; };
 | 
					  auto add = [](const int& y, double x) { return y + x; };
 | 
				
			||||||
  double sum = tree.fold(add, 0.0);
 | 
					  double sum = tree.fold(add, 0.0);
 | 
				
			||||||
  EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
 | 
					  EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);  // Note, not 7, due to pruning!
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ************************************************************************** */
 | 
					/* ************************************************************************** */
 | 
				
			||||||
| 
						 | 
					@ -365,7 +365,7 @@ TEST(DecisionTree, labels) {
 | 
				
			||||||
  EXPECT_LONGS_EQUAL(2, labels.size());
 | 
					  EXPECT_LONGS_EQUAL(2, labels.size());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ******************************************************************************** */
 | 
					/* ************************************************************************** */
 | 
				
			||||||
// Test unzip method.
 | 
					// Test unzip method.
 | 
				
			||||||
TEST(DecisionTree, unzip) {
 | 
					TEST(DecisionTree, unzip) {
 | 
				
			||||||
  using DTP = DecisionTree<string, std::pair<int, string>>;
 | 
					  using DTP = DecisionTree<string, std::pair<int, string>>;
 | 
				
			||||||
| 
						 | 
					@ -374,15 +374,13 @@ TEST(DecisionTree, unzip) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Create small two-level tree
 | 
					  // Create small two-level tree
 | 
				
			||||||
  string A("A"), B("B"), C("C");
 | 
					  string A("A"), B("B"), C("C");
 | 
				
			||||||
  DTP tree(B,
 | 
					  DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
 | 
				
			||||||
           DTP(A, {0, "zero"}, {1, "one"}),
 | 
					           DTP(A, {2, "two"}, {1337, "l33t"}));
 | 
				
			||||||
           DTP(A, {2, "two"}, {1337, "l33t"})
 | 
					 | 
				
			||||||
  );
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  DT1 dt1;
 | 
					  DT1 dt1;
 | 
				
			||||||
  DT2 dt2;
 | 
					  DT2 dt2;
 | 
				
			||||||
  std::tie(dt1, dt2) = unzip(tree);
 | 
					  std::tie(dt1, dt2) = unzip(tree);
 | 
				
			||||||
  
 | 
					
 | 
				
			||||||
  DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
 | 
					  DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
 | 
				
			||||||
  DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
 | 
					  DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -398,7 +396,7 @@ TEST(DecisionTree, threshold) {
 | 
				
			||||||
  keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
 | 
					  keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
 | 
				
			||||||
  DT tree(keys, "0 1 2 3 4 5 6 7");
 | 
					  DT tree(keys, "0 1 2 3 4 5 6 7");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Check number of elements equal to zero
 | 
					  // Check number of leaves equal to zero
 | 
				
			||||||
  auto count = [](const int& value, int count) {
 | 
					  auto count = [](const int& value, int count) {
 | 
				
			||||||
    return value == 0 ? count + 1 : count;
 | 
					    return value == 0 ? count + 1 : count;
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
| 
						 | 
					@ -408,9 +406,9 @@ TEST(DecisionTree, threshold) {
 | 
				
			||||||
  auto threshold = [](int value) { return value < 5 ? 0 : value; };
 | 
					  auto threshold = [](int value) { return value < 5 ? 0 : value; };
 | 
				
			||||||
  DT thresholded(tree, threshold);
 | 
					  DT thresholded(tree, threshold);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Check number of elements equal to zero now = 5
 | 
					  // Check number of leaves equal to zero now = 2
 | 
				
			||||||
  // TODO(frank): it is 2, because the pruned branches are counted as 1!
 | 
					  // Note: it is 2, because the pruned branches are counted as 1!
 | 
				
			||||||
  EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0));
 | 
					  EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ************************************************************************* */
 | 
					/* ************************************************************************* */
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue