/* ---------------------------------------------------------------------------- * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) * See LICENSE for the license information * -------------------------------------------------------------------------- */ /* * @file testDecisionTree.cpp * @brief Develop DecisionTree * @author Frank Dellaert * @author Can Erdogan * @date Jan 30, 2012 */ #include using namespace boost::assign; #include #include #include // #define DT_DEBUG_MEMORY // #define DT_NO_PRUNING #define DISABLE_DOT #include using namespace std; using namespace gtsam; template void dot(const T& f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif } #define DOT(x) (dot(x, #x)) struct Crazy { int a; double b; }; struct CrazyDecisionTree : public DecisionTree { /// print to stdout void print(const std::string& s = "") const { auto keyFormatter = [](const std::string& s) { return s; }; auto valueFormatter = [](const Crazy& v) { return (boost::format("{%d,%4.2g}") % v.a % v.b).str(); }; DecisionTree::print("", keyFormatter, valueFormatter); } /// Equality method customized to Crazy node type bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const { auto compare = [tol](const Crazy& v, const Crazy& w) { return v.a == w.a && std::abs(v.b - w.b) < tol; }; return DecisionTree::equals(other, compare); } }; // traits namespace gtsam { template <> struct traits : public Testable {}; } // namespace gtsam GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) /* ************************************************************************** */ // Test string labels and int range /* ************************************************************************** */ struct DT : public DecisionTree { using Base = DecisionTree; using DecisionTree::DecisionTree; DT() = default; DT(const Base& dt) : Base(dt) {} /// print to stdout void print(const std::string& s = "") const { auto keyFormatter = [](const std::string& s) { return s; }; auto valueFormatter = [](const int& v) { return (boost::format("%d") % v).str(); }; Base::print("", keyFormatter, valueFormatter); } /// Equality method customized to int node type bool equals(const Base& other, double tol = 1e-9) const { auto compare = [](const int& v, const int& w) { return v == w; }; return Base::equals(other, compare); } }; // traits namespace gtsam { template <> struct traits
: public Testable
{}; } // namespace gtsam GTSAM_CONCEPT_TESTABLE_INST(DT) struct Ring { static inline int zero() { return 0; } static inline int one() { return 1; } static inline int id(const int& a) { return a; } static inline int add(const int& a, const int& b) { return a + b; } static inline int mul(const int& a, const int& b) { return a * b; } }; /* ************************************************************************** */ // test DT TEST(DecisionTree, example) { // Create labels string A("A"), B("B"), C("C"); // create a value Assignment x00, x01, x10, x11; x00[A] = 0, x00[B] = 0; x01[A] = 0, x01[B] = 1; x10[A] = 1, x10[B] = 0; x11[A] = 1, x11[B] = 1; // empty DT empty; // A DT a(A, 0, 5); LONGS_EQUAL(0, a(x00)) LONGS_EQUAL(5, a(x10)) DOT(a); // pruned DT p(A, 2, 2); LONGS_EQUAL(2, p(x00)) LONGS_EQUAL(2, p(x10)) DOT(p); // \neg B DT notb(B, 5, 0); LONGS_EQUAL(5, notb(x00)) LONGS_EQUAL(5, notb(x10)) DOT(notb); // Check supplying empty trees yields an exception CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); // apply, two nodes, in natural order DT anotb = apply(a, notb, &Ring::mul); LONGS_EQUAL(0, anotb(x00)) LONGS_EQUAL(0, anotb(x01)) LONGS_EQUAL(25, anotb(x10)) LONGS_EQUAL(0, anotb(x11)) DOT(anotb); // check pruning DT pnotb = apply(p, notb, &Ring::mul); LONGS_EQUAL(10, pnotb(x00)) LONGS_EQUAL(0, pnotb(x01)) LONGS_EQUAL(10, pnotb(x10)) LONGS_EQUAL(0, pnotb(x11)) DOT(pnotb); // check pruning DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); LONGS_EQUAL(0, zeros(x00)) LONGS_EQUAL(0, zeros(x01)) LONGS_EQUAL(0, zeros(x10)) LONGS_EQUAL(0, zeros(x11)) DOT(zeros); // apply, two nodes, in switched order DT notba = apply(a, notb, &Ring::mul); LONGS_EQUAL(0, notba(x00)) LONGS_EQUAL(0, notba(x01)) LONGS_EQUAL(25, notba(x10)) LONGS_EQUAL(0, notba(x11)) DOT(notba); // Test choose 0 DT actual0 = notba.choose(A, 0); EXPECT(assert_equal(DT(0.0), actual0)); DOT(actual0); // Test choose 1 DT actual1 = notba.choose(A, 1); EXPECT(assert_equal(DT(B, 25, 0), actual1)); DOT(actual1); // apply, two nodes at same level DT a_and_a = apply(a, a, &Ring::mul); LONGS_EQUAL(0, a_and_a(x00)) LONGS_EQUAL(0, a_and_a(x01)) LONGS_EQUAL(25, a_and_a(x10)) LONGS_EQUAL(25, a_and_a(x11)) DOT(a_and_a); // create a function on C DT c(C, 0, 5); // and a model assigning stuff to C Assignment x101; x101[A] = 1, x101[B] = 0, x101[C] = 1; // mul notba with C DT notbac = apply(notba, c, &Ring::mul); LONGS_EQUAL(125, notbac(x101)) DOT(notbac); // mul now in different order DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); LONGS_EQUAL(125, acnotb(x101)) DOT(acnotb); } /* ************************************************************************** */ // test Conversion of values bool bool_of_int(const int& y) { return y != 0; }; typedef DecisionTree StringBoolTree; TEST(DecisionTree, ConvertValuesOnly) { // Create labels string A("A"), B("B"); // apply, two nodes, in natural order DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul); // convert StringBoolTree f2(f1, bool_of_int); // Check a value Assignment x00; x00["A"] = 0, x00["B"] = 0; EXPECT(!f2(x00)); } /* ************************************************************************** */ // test Conversion of both values and labels. enum Label { U, V, X, Y, Z }; typedef DecisionTree LabelBoolTree; TEST(DecisionTree, ConvertBoth) { // Create labels string A("A"), B("B"); // apply, two nodes, in natural order DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul); // convert map ordering; ordering[A] = X; ordering[B] = Y; LabelBoolTree f2(f1, ordering, &bool_of_int); // Check some values Assignment