Renamed protected method convert -> convertFrom

release/4.3a0
Frank Dellaert 2022-01-02 18:08:45 -05:00
parent 5c4038c7c0
commit 6c23fd1e86
4 changed files with 130 additions and 135 deletions

View File

@ -115,11 +115,11 @@ namespace gtsam {
template<typename M> template<typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other, AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) { const std::map<M, L>& map) {
std::function<L(const M&)> map_function = [&map](const M& label) -> L { std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
return map.at(label); return map.at(label);
}; };
std::function<double(const double&)> op = Ring::id; std::function<double(const double&)> op = Ring::id;
this->root_ = this->template convert(other.root_, op, map_function); this->root_ = this->template convertFrom(other.root_, L_of_M, op);
} }
/** sum */ /** sum */

View File

@ -461,20 +461,21 @@ namespace gtsam {
template <typename L, typename Y> template <typename L, typename Y>
template <typename X> template <typename X>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> op) { std::function<Y(const X&)> Y_of_X) {
auto map = [](const L& label) { return label; }; auto L_of_L = [](const L& label) { return label; };
root_ = other.template convert<L, X>(op, map); root_ = convertFrom<L, X>(Y_of_X, L_of_L);
} }
/*********************************************************************************/ /*********************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X> template <typename M, typename X>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
const std::map<M, L>& map, std::function<Y(const X&)> op) { const std::map<M, L>& map,
std::function<L(const M&)> map_function = [&map](const M& label) -> L { std::function<Y(const X&)> Y_of_X) {
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
return map.at(label); return map.at(label);
}; };
root_ = other.template convert<M, X>(op, map_function); root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
} }
/*********************************************************************************/ /*********************************************************************************/
@ -589,9 +590,10 @@ namespace gtsam {
/*********************************************************************************/ /*********************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X> template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
const typename DecisionTree<M, X>::NodePtr& f, const typename DecisionTree<M, X>::NodePtr& f,
std::function<Y(const X&)> op, std::function<L(const M&)> map) const { std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const {
typedef DecisionTree<M, X> MX; typedef DecisionTree<M, X> MX;
typedef typename MX::Leaf MXLeaf; typedef typename MX::Leaf MXLeaf;
typedef typename MX::Choice MXChoice; typedef typename MX::Choice MXChoice;
@ -601,7 +603,7 @@ namespace gtsam {
// ugliness below because apparently we can't have templated virtual functions // ugliness below because apparently we can't have templated virtual functions
// If leaf, apply unary conversion "op" and create a unique leaf // If leaf, apply unary conversion "op" and create a unique leaf
auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f); auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant())));
// Check if Choice // Check if Choice
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f); auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
@ -610,12 +612,12 @@ namespace gtsam {
// get new label // get new label
const M oldLabel = choice->label(); const M oldLabel = choice->label();
const L newLabel = map(oldLabel); const L newLabel = L_of_M(oldLabel);
// put together via Shannon expansion otherwise not sorted. // put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions; std::vector<LY> functions;
for(const MXNodePtr& branch: choice->branches()) { for(const MXNodePtr& branch: choice->branches()) {
LY converted(convert<M, X>(branch, op, map)); LY converted(convertFrom<M, X>(branch, L_of_M, Y_of_X));
functions += converted; functions += converted;
} }
return LY::compose(functions.begin(), functions.end(), newLabel); return LY::compose(functions.begin(), functions.end(), newLabel);

View File

@ -113,27 +113,20 @@ namespace gtsam {
/** A function is a shared pointer to the root of a DT */ /** A function is a shared pointer to the root of a DT */
typedef typename Node::Ptr NodePtr; typedef typename Node::Ptr NodePtr;
protected: /// a DecisionTree just contains the root. TODO(dellaert): make protected.
/* a DecisionTree just contains the root */
NodePtr root_; NodePtr root_;
protected:
/** Internal recursive function to create from keys, cardinalities, and Y values */ /** Internal recursive function to create from keys, cardinalities, and Y values */
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
/// Convert to a different type, will not convert label if map empty. /// Convert from a DecisionTree<M, X>.
template <typename M, typename X> template <typename M, typename X>
NodePtr convert(const typename DecisionTree<M, X>::NodePtr& f, NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
std::function<Y(const X&)> op, std::function<L(const M&)> L_of_M,
std::function<L(const M&)> map) const; std::function<Y(const X&)> Y_of_X) const;
/// Convert to a different type, will not convert label if map empty.
template <typename M, typename X>
NodePtr convert(std::function<Y(const X&)> op,
std::function<L(const M&)> map) const {
return convert(root_, op, map);
}
public: public:
@ -169,12 +162,12 @@ namespace gtsam {
/** Convert from a different type. */ /** Convert from a different type. */
template <typename X> template <typename X>
DecisionTree(const DecisionTree<L, X>& other, DecisionTree(const DecisionTree<L, X>& other,
std::function<Y(const X&)> op); std::function<Y(const X&)> Y_of_X);
/** Convert from a different type, also transate labels via map. */ /** Convert from a different type, also transate labels via map. */
template <typename M, typename X> template <typename M, typename X>
DecisionTree(const DecisionTree<M, X>& other, DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
const std::map<M, L>& map, std::function<Y(const X&)> op); std::function<Y(const X&)> Y_of_X);
/// @} /// @}
/// @name Testable /// @name Testable

View File

@ -84,8 +84,11 @@ GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ******************************************************************************** */ /* ******************************************************************************** */
struct DT : public DecisionTree<string, int> { struct DT : public DecisionTree<string, int> {
using Base = DecisionTree<string, int>;
using DecisionTree::DecisionTree; using DecisionTree::DecisionTree;
DT(const DecisionTree<string, int>& dt) : root_(dt.root_) {} DT() = default;
DT(const Base& dt) : Base(dt) {}
/// print to stdout /// print to stdout
void print(const std::string& s = "") const { void print(const std::string& s = "") const {
@ -93,15 +96,13 @@ struct DT : public DecisionTree<string, int> {
auto valueFormatter = [](const int& v) { auto valueFormatter = [](const int& v) {
return (boost::format("%d") % v).str(); return (boost::format("%d") % v).str();
}; };
DecisionTree<string, int>::print("", keyFormatter, valueFormatter); Base::print("", keyFormatter, valueFormatter);
}
/// Equality method customized to int node type
bool equals(const Base& other, double tol = 1e-9) const {
auto compare = [](const int& v, const int& w) { return v == w; };
return Base::equals(other, compare);
} }
// /// Equality method customized to int node type
// bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
// auto compare = [tol](const int& v, const int& w) {
// return v.a == w.a && std::abs(v.b - w.b) < tol;
// };
// return DecisionTree<string, int>::equals(other, compare);
// }
}; };
// traits // traits
@ -131,111 +132,111 @@ struct Ring {
/* ******************************************************************************** */ /* ******************************************************************************** */
// test DT // test DT
// TEST(DT, example) TEST(DT, example)
// { {
// // Create labels // Create labels
// string A("A"), B("B"), C("C"); string A("A"), B("B"), C("C");
// // create a value // create a value
// Assignment<string> x00, x01, x10, x11; Assignment<string> x00, x01, x10, x11;
// x00[A] = 0, x00[B] = 0; x00[A] = 0, x00[B] = 0;
// x01[A] = 0, x01[B] = 1; x01[A] = 0, x01[B] = 1;
// x10[A] = 1, x10[B] = 0; x10[A] = 1, x10[B] = 0;
// x11[A] = 1, x11[B] = 1; x11[A] = 1, x11[B] = 1;
// // empty // empty
// DT empty; DT empty;
// // A // A
// DT a(A, 0, 5); DT a(A, 0, 5);
// LONGS_EQUAL(0,a(x00)) LONGS_EQUAL(0,a(x00))
// LONGS_EQUAL(5,a(x10)) LONGS_EQUAL(5,a(x10))
// DOT(a); DOT(a);
// // pruned // pruned
// DT p(A, 2, 2); DT p(A, 2, 2);
// LONGS_EQUAL(2,p(x00)) LONGS_EQUAL(2,p(x00))
// LONGS_EQUAL(2,p(x10)) LONGS_EQUAL(2,p(x10))
// DOT(p); DOT(p);
// // \neg B // \neg B
// DT notb(B, 5, 0); DT notb(B, 5, 0);
// LONGS_EQUAL(5,notb(x00)) LONGS_EQUAL(5,notb(x00))
// LONGS_EQUAL(5,notb(x10)) LONGS_EQUAL(5,notb(x10))
// DOT(notb); DOT(notb);
// // Check supplying empty trees yields an exception // Check supplying empty trees yields an exception
// CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error); CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
// CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error); CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
// CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error); CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
// // apply, two nodes, in natural order // apply, two nodes, in natural order
// DT anotb = apply(a, notb, &Ring::mul); DT anotb = apply(a, notb, &Ring::mul);
// LONGS_EQUAL(0,anotb(x00)) LONGS_EQUAL(0,anotb(x00))
// LONGS_EQUAL(0,anotb(x01)) LONGS_EQUAL(0,anotb(x01))
// LONGS_EQUAL(25,anotb(x10)) LONGS_EQUAL(25,anotb(x10))
// LONGS_EQUAL(0,anotb(x11)) LONGS_EQUAL(0,anotb(x11))
// DOT(anotb); DOT(anotb);
// // check pruning // check pruning
// DT pnotb = apply(p, notb, &Ring::mul); DT pnotb = apply(p, notb, &Ring::mul);
// LONGS_EQUAL(10,pnotb(x00)) LONGS_EQUAL(10,pnotb(x00))
// LONGS_EQUAL( 0,pnotb(x01)) LONGS_EQUAL( 0,pnotb(x01))
// LONGS_EQUAL(10,pnotb(x10)) LONGS_EQUAL(10,pnotb(x10))
// LONGS_EQUAL( 0,pnotb(x11)) LONGS_EQUAL( 0,pnotb(x11))
// DOT(pnotb); DOT(pnotb);
// // check pruning // check pruning
// DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
// LONGS_EQUAL(0,zeros(x00)) LONGS_EQUAL(0,zeros(x00))
// LONGS_EQUAL(0,zeros(x01)) LONGS_EQUAL(0,zeros(x01))
// LONGS_EQUAL(0,zeros(x10)) LONGS_EQUAL(0,zeros(x10))
// LONGS_EQUAL(0,zeros(x11)) LONGS_EQUAL(0,zeros(x11))
// DOT(zeros); DOT(zeros);
// // apply, two nodes, in switched order // apply, two nodes, in switched order
// DT notba = apply(a, notb, &Ring::mul); DT notba = apply(a, notb, &Ring::mul);
// LONGS_EQUAL(0,notba(x00)) LONGS_EQUAL(0,notba(x00))
// LONGS_EQUAL(0,notba(x01)) LONGS_EQUAL(0,notba(x01))
// LONGS_EQUAL(25,notba(x10)) LONGS_EQUAL(25,notba(x10))
// LONGS_EQUAL(0,notba(x11)) LONGS_EQUAL(0,notba(x11))
// DOT(notba); DOT(notba);
// // Test choose 0 // Test choose 0
// DT actual0 = notba.choose(A, 0); DT actual0 = notba.choose(A, 0);
// EXPECT(assert_equal(DT(0.0), actual0)); EXPECT(assert_equal(DT(0.0), actual0));
// DOT(actual0); DOT(actual0);
// // Test choose 1 // Test choose 1
// DT actual1 = notba.choose(A, 1); DT actual1 = notba.choose(A, 1);
// EXPECT(assert_equal(DT(B, 25, 0), actual1)); EXPECT(assert_equal(DT(B, 25, 0), actual1));
// DOT(actual1); DOT(actual1);
// // apply, two nodes at same level // apply, two nodes at same level
// DT a_and_a = apply(a, a, &Ring::mul); DT a_and_a = apply(a, a, &Ring::mul);
// LONGS_EQUAL(0,a_and_a(x00)) LONGS_EQUAL(0,a_and_a(x00))
// LONGS_EQUAL(0,a_and_a(x01)) LONGS_EQUAL(0,a_and_a(x01))
// LONGS_EQUAL(25,a_and_a(x10)) LONGS_EQUAL(25,a_and_a(x10))
// LONGS_EQUAL(25,a_and_a(x11)) LONGS_EQUAL(25,a_and_a(x11))
// DOT(a_and_a); DOT(a_and_a);
// // create a function on C // create a function on C
// DT c(C, 0, 5); DT c(C, 0, 5);
// // and a model assigning stuff to C // and a model assigning stuff to C
// Assignment<string> x101; Assignment<string> x101;
// x101[A] = 1, x101[B] = 0, x101[C] = 1; x101[A] = 1, x101[B] = 0, x101[C] = 1;
// // mul notba with C // mul notba with C
// DT notbac = apply(notba, c, &Ring::mul); DT notbac = apply(notba, c, &Ring::mul);
// LONGS_EQUAL(125,notbac(x101)) LONGS_EQUAL(125,notbac(x101))
// DOT(notbac); DOT(notbac);
// // mul now in different order // mul now in different order
// DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
// LONGS_EQUAL(125,acnotb(x101)) LONGS_EQUAL(125,acnotb(x101))
// DOT(acnotb); DOT(acnotb);
// } }
/* ******************************************************************************** */ /* ******************************************************************************** */
// test Conversion // test Conversion
@ -243,9 +244,6 @@ enum Label {
U, V, X, Y, Z U, V, X, Y, Z
}; };
typedef DecisionTree<Label, bool> BDT; typedef DecisionTree<Label, bool> BDT;
bool convert(const int& y) {
return y != 0;
}
TEST(DT, conversion) TEST(DT, conversion)
{ {
@ -259,8 +257,10 @@ TEST(DT, conversion)
map<string, Label> ordering; map<string, Label> ordering;
ordering[A] = X; ordering[A] = X;
ordering[B] = Y; ordering[B] = Y;
std::function<bool(const int&)> op = convert; std::function<bool(const int&)> bool_of_int = [](const int& y) {
BDT f2(f1, ordering, op); return y != 0;
};
BDT f2(f1, ordering, bool_of_int);
// f1.print("f1"); // f1.print("f1");
// f2.print("f2"); // f2.print("f2");