diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 37130ab72..17a38f7cf 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -115,11 +115,11 @@ namespace gtsam { template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, const std::map& map) { - std::function map_function = [&map](const M& label) -> L { + std::function L_of_M = [&map](const M& label) -> L { return map.at(label); }; std::function op = Ring::id; - this->root_ = this->template convert(other.root_, op, map_function); + this->root_ = this->template convertFrom(other.root_, L_of_M, op); } /** sum */ diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index b31773702..af52a6daf 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -461,20 +461,21 @@ namespace gtsam { template template DecisionTree::DecisionTree(const DecisionTree& other, - std::function op) { - auto map = [](const L& label) { return label; }; - root_ = other.template convert(op, map); + std::function Y_of_X) { + auto L_of_L = [](const L& label) { return label; }; + root_ = convertFrom(Y_of_X, L_of_L); } /*********************************************************************************/ - template - template + template + template DecisionTree::DecisionTree(const DecisionTree& other, - const std::map& map, std::function op) { - std::function map_function = [&map](const M& label) -> L { + const std::map& map, + std::function Y_of_X) { + std::function L_of_M = [&map](const M& label) -> L { return map.at(label); }; - root_ = other.template convert(op, map_function); + root_ = convertFrom(other.root_, L_of_M, Y_of_X); } /*********************************************************************************/ @@ -589,9 +590,10 @@ namespace gtsam { /*********************************************************************************/ template template - typename DecisionTree::NodePtr DecisionTree::convert( + typename DecisionTree::NodePtr DecisionTree::convertFrom( const typename DecisionTree::NodePtr& f, - std::function op, std::function map) const { + std::function L_of_M, + std::function Y_of_X) const { typedef DecisionTree MX; typedef typename MX::Leaf MXLeaf; typedef typename MX::Choice MXChoice; @@ -601,7 +603,7 @@ namespace gtsam { // ugliness below because apparently we can't have templated virtual functions // If leaf, apply unary conversion "op" and create a unique leaf auto leaf = boost::dynamic_pointer_cast(f); - if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); + if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant()))); // Check if Choice auto choice = boost::dynamic_pointer_cast(f); @@ -610,12 +612,12 @@ namespace gtsam { // get new label const M oldLabel = choice->label(); - const L newLabel = map(oldLabel); + const L newLabel = L_of_M(oldLabel); // put together via Shannon expansion otherwise not sorted. std::vector functions; for(const MXNodePtr& branch: choice->branches()) { - LY converted(convert(branch, op, map)); + LY converted(convertFrom(branch, L_of_M, Y_of_X)); functions += converted; } return LY::compose(functions.begin(), functions.end(), newLabel); diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 4c79c5841..ecc3d17dc 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -113,27 +113,20 @@ namespace gtsam { /** A function is a shared pointer to the root of a DT */ typedef typename Node::Ptr NodePtr; - protected: - - /* a DecisionTree just contains the root */ + /// a DecisionTree just contains the root. TODO(dellaert): make protected. NodePtr root_; + protected: + /** Internal recursive function to create from keys, cardinalities, and Y values */ template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; - /// Convert to a different type, will not convert label if map empty. + /// Convert from a DecisionTree. template - NodePtr convert(const typename DecisionTree::NodePtr& f, - std::function op, - std::function map) const; - - /// Convert to a different type, will not convert label if map empty. - template - NodePtr convert(std::function op, - std::function map) const { - return convert(root_, op, map); - } + NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X) const; public: @@ -169,12 +162,12 @@ namespace gtsam { /** Convert from a different type. */ template DecisionTree(const DecisionTree& other, - std::function op); + std::function Y_of_X); /** Convert from a different type, also transate labels via map. */ - template - DecisionTree(const DecisionTree& other, - const std::map& map, std::function op); + template + DecisionTree(const DecisionTree& other, const std::map& L_of_M, + std::function Y_of_X); /// @} /// @name Testable diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index f42a590ae..cc61a382f 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -84,8 +84,11 @@ GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) /* ******************************************************************************** */ struct DT : public DecisionTree { + using Base = DecisionTree; using DecisionTree::DecisionTree; - DT(const DecisionTree& dt) : root_(dt.root_) {} + DT() = default; + + DT(const Base& dt) : Base(dt) {} /// print to stdout void print(const std::string& s = "") const { @@ -93,15 +96,13 @@ struct DT : public DecisionTree { auto valueFormatter = [](const int& v) { return (boost::format("%d") % v).str(); }; - DecisionTree::print("", keyFormatter, valueFormatter); + Base::print("", keyFormatter, valueFormatter); + } + /// Equality method customized to int node type + bool equals(const Base& other, double tol = 1e-9) const { + auto compare = [](const int& v, const int& w) { return v == w; }; + return Base::equals(other, compare); } - // /// Equality method customized to int node type - // bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const { - // auto compare = [tol](const int& v, const int& w) { - // return v.a == w.a && std::abs(v.b - w.b) < tol; - // }; - // return DecisionTree::equals(other, compare); - // } }; // traits @@ -131,111 +132,111 @@ struct Ring { /* ******************************************************************************** */ // test DT -// TEST(DT, example) -// { -// // Create labels -// string A("A"), B("B"), C("C"); +TEST(DT, example) +{ + // Create labels + string A("A"), B("B"), C("C"); -// // create a value -// Assignment x00, x01, x10, x11; -// x00[A] = 0, x00[B] = 0; -// x01[A] = 0, x01[B] = 1; -// x10[A] = 1, x10[B] = 0; -// x11[A] = 1, x11[B] = 1; + // create a value + Assignment x00, x01, x10, x11; + x00[A] = 0, x00[B] = 0; + x01[A] = 0, x01[B] = 1; + x10[A] = 1, x10[B] = 0; + x11[A] = 1, x11[B] = 1; -// // empty -// DT empty; + // empty + DT empty; -// // A -// DT a(A, 0, 5); -// LONGS_EQUAL(0,a(x00)) -// LONGS_EQUAL(5,a(x10)) -// DOT(a); + // A + DT a(A, 0, 5); + LONGS_EQUAL(0,a(x00)) + LONGS_EQUAL(5,a(x10)) + DOT(a); -// // pruned -// DT p(A, 2, 2); -// LONGS_EQUAL(2,p(x00)) -// LONGS_EQUAL(2,p(x10)) -// DOT(p); + // pruned + DT p(A, 2, 2); + LONGS_EQUAL(2,p(x00)) + LONGS_EQUAL(2,p(x10)) + DOT(p); -// // \neg B -// DT notb(B, 5, 0); -// LONGS_EQUAL(5,notb(x00)) -// LONGS_EQUAL(5,notb(x10)) -// DOT(notb); + // \neg B + DT notb(B, 5, 0); + LONGS_EQUAL(5,notb(x00)) + LONGS_EQUAL(5,notb(x10)) + DOT(notb); -// // Check supplying empty trees yields an exception -// CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); -// CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); -// CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); + // Check supplying empty trees yields an exception + CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); + CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); + CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); -// // apply, two nodes, in natural order -// DT anotb = apply(a, notb, &Ring::mul); -// LONGS_EQUAL(0,anotb(x00)) -// LONGS_EQUAL(0,anotb(x01)) -// LONGS_EQUAL(25,anotb(x10)) -// LONGS_EQUAL(0,anotb(x11)) -// DOT(anotb); + // apply, two nodes, in natural order + DT anotb = apply(a, notb, &Ring::mul); + LONGS_EQUAL(0,anotb(x00)) + LONGS_EQUAL(0,anotb(x01)) + LONGS_EQUAL(25,anotb(x10)) + LONGS_EQUAL(0,anotb(x11)) + DOT(anotb); -// // check pruning -// DT pnotb = apply(p, notb, &Ring::mul); -// LONGS_EQUAL(10,pnotb(x00)) -// LONGS_EQUAL( 0,pnotb(x01)) -// LONGS_EQUAL(10,pnotb(x10)) -// LONGS_EQUAL( 0,pnotb(x11)) -// DOT(pnotb); + // check pruning + DT pnotb = apply(p, notb, &Ring::mul); + LONGS_EQUAL(10,pnotb(x00)) + LONGS_EQUAL( 0,pnotb(x01)) + LONGS_EQUAL(10,pnotb(x10)) + LONGS_EQUAL( 0,pnotb(x11)) + DOT(pnotb); -// // check pruning -// DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); -// LONGS_EQUAL(0,zeros(x00)) -// LONGS_EQUAL(0,zeros(x01)) -// LONGS_EQUAL(0,zeros(x10)) -// LONGS_EQUAL(0,zeros(x11)) -// DOT(zeros); + // check pruning + DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); + LONGS_EQUAL(0,zeros(x00)) + LONGS_EQUAL(0,zeros(x01)) + LONGS_EQUAL(0,zeros(x10)) + LONGS_EQUAL(0,zeros(x11)) + DOT(zeros); -// // apply, two nodes, in switched order -// DT notba = apply(a, notb, &Ring::mul); -// LONGS_EQUAL(0,notba(x00)) -// LONGS_EQUAL(0,notba(x01)) -// LONGS_EQUAL(25,notba(x10)) -// LONGS_EQUAL(0,notba(x11)) -// DOT(notba); + // apply, two nodes, in switched order + DT notba = apply(a, notb, &Ring::mul); + LONGS_EQUAL(0,notba(x00)) + LONGS_EQUAL(0,notba(x01)) + LONGS_EQUAL(25,notba(x10)) + LONGS_EQUAL(0,notba(x11)) + DOT(notba); -// // Test choose 0 -// DT actual0 = notba.choose(A, 0); -// EXPECT(assert_equal(DT(0.0), actual0)); -// DOT(actual0); + // Test choose 0 + DT actual0 = notba.choose(A, 0); + EXPECT(assert_equal(DT(0.0), actual0)); + DOT(actual0); -// // Test choose 1 -// DT actual1 = notba.choose(A, 1); -// EXPECT(assert_equal(DT(B, 25, 0), actual1)); -// DOT(actual1); + // Test choose 1 + DT actual1 = notba.choose(A, 1); + EXPECT(assert_equal(DT(B, 25, 0), actual1)); + DOT(actual1); -// // apply, two nodes at same level -// DT a_and_a = apply(a, a, &Ring::mul); -// LONGS_EQUAL(0,a_and_a(x00)) -// LONGS_EQUAL(0,a_and_a(x01)) -// LONGS_EQUAL(25,a_and_a(x10)) -// LONGS_EQUAL(25,a_and_a(x11)) -// DOT(a_and_a); + // apply, two nodes at same level + DT a_and_a = apply(a, a, &Ring::mul); + LONGS_EQUAL(0,a_and_a(x00)) + LONGS_EQUAL(0,a_and_a(x01)) + LONGS_EQUAL(25,a_and_a(x10)) + LONGS_EQUAL(25,a_and_a(x11)) + DOT(a_and_a); -// // create a function on C -// DT c(C, 0, 5); + // create a function on C + DT c(C, 0, 5); -// // and a model assigning stuff to C -// Assignment x101; -// x101[A] = 1, x101[B] = 0, x101[C] = 1; + // and a model assigning stuff to C + Assignment x101; + x101[A] = 1, x101[B] = 0, x101[C] = 1; -// // mul notba with C -// DT notbac = apply(notba, c, &Ring::mul); -// LONGS_EQUAL(125,notbac(x101)) -// DOT(notbac); + // mul notba with C + DT notbac = apply(notba, c, &Ring::mul); + LONGS_EQUAL(125,notbac(x101)) + DOT(notbac); -// // mul now in different order -// DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); -// LONGS_EQUAL(125,acnotb(x101)) -// DOT(acnotb); -// } + // mul now in different order + DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); + LONGS_EQUAL(125,acnotb(x101)) + DOT(acnotb); +} /* ******************************************************************************** */ // test Conversion @@ -243,9 +244,6 @@ enum Label { U, V, X, Y, Z }; typedef DecisionTree BDT; -bool convert(const int& y) { - return y != 0; -} TEST(DT, conversion) { @@ -259,8 +257,10 @@ TEST(DT, conversion) map ordering; ordering[A] = X; ordering[B] = Y; - std::function op = convert; - BDT f2(f1, ordering, op); + std::function bool_of_int = [](const int& y) { + return y != 0; + }; + BDT f2(f1, ordering, bool_of_int); // f1.print("f1"); // f2.print("f2");