Thresholding test
parent
9317e94452
commit
241906d2c9
|
|
@ -308,7 +308,7 @@ TEST(DecisionTree, Containers) {
|
||||||
StringContainerTree tree;
|
StringContainerTree tree;
|
||||||
|
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
|
||||||
// Check conversion
|
// Check conversion
|
||||||
|
|
@ -324,7 +324,7 @@ TEST(DecisionTree, Containers) {
|
||||||
// Test visit.
|
// Test visit.
|
||||||
TEST(DecisionTree, visit) {
|
TEST(DecisionTree, visit) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
auto visitor = [&](int y) { sum += y; };
|
auto visitor = [&](int y) { sum += y; };
|
||||||
|
|
@ -336,7 +336,7 @@ TEST(DecisionTree, visit) {
|
||||||
// Test visit, with Choices argument.
|
// Test visit, with Choices argument.
|
||||||
TEST(DecisionTree, visitWith) {
|
TEST(DecisionTree, visitWith) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
||||||
|
|
@ -348,7 +348,7 @@ TEST(DecisionTree, visitWith) {
|
||||||
// Test fold.
|
// Test fold.
|
||||||
TEST(DecisionTree, fold) {
|
TEST(DecisionTree, fold) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
auto add = [](const int& y, double x) { return y + x; };
|
auto add = [](const int& y, double x) { return y + x; };
|
||||||
double sum = tree.fold(add, 0.0);
|
double sum = tree.fold(add, 0.0);
|
||||||
|
|
@ -359,14 +359,14 @@ TEST(DecisionTree, fold) {
|
||||||
// Test retrieving all labels.
|
// Test retrieving all labels.
|
||||||
TEST(DecisionTree, labels) {
|
TEST(DecisionTree, labels) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
auto labels = tree.labels();
|
auto labels = tree.labels();
|
||||||
EXPECT_LONGS_EQUAL(2, labels.size());
|
EXPECT_LONGS_EQUAL(2, labels.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// Test retrieving all labels.
|
// Test unzip method.
|
||||||
TEST(DecisionTree, unzip) {
|
TEST(DecisionTree, unzip) {
|
||||||
using DTP = DecisionTree<string, std::pair<int, string>>;
|
using DTP = DecisionTree<string, std::pair<int, string>>;
|
||||||
using DT1 = DecisionTree<string, int>;
|
using DT1 = DecisionTree<string, int>;
|
||||||
|
|
@ -390,6 +390,29 @@ TEST(DecisionTree, unzip) {
|
||||||
EXPECT(tree2.equals(dt2));
|
EXPECT(tree2.equals(dt2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test thresholding.
|
||||||
|
TEST(DecisionTree, threshold) {
|
||||||
|
// Create three level tree
|
||||||
|
vector<DT::LabelC> keys;
|
||||||
|
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
|
||||||
|
DT tree(keys, "0 1 2 3 4 5 6 7");
|
||||||
|
|
||||||
|
// Check number of elements equal to zero
|
||||||
|
auto count = [](const int& value, int count) {
|
||||||
|
return value == 0 ? count + 1 : count;
|
||||||
|
};
|
||||||
|
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
|
||||||
|
|
||||||
|
// Now threshold
|
||||||
|
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
||||||
|
DT thresholded(tree, threshold);
|
||||||
|
|
||||||
|
// Check number of elements equal to zero now = 5
|
||||||
|
// TODO(frank): it is 2, because the pruned branches are counted as 1!
|
||||||
|
EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue