Renamed protected method convert -> convertFrom
parent
5c4038c7c0
commit
6c23fd1e86
|
|
@ -115,11 +115,11 @@ namespace gtsam {
|
||||||
template<typename M>
|
template<typename M>
|
||||||
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
|
||||||
const std::map<M, L>& map) {
|
const std::map<M, L>& map) {
|
||||||
std::function<L(const M&)> map_function = [&map](const M& label) -> L {
|
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
|
||||||
return map.at(label);
|
return map.at(label);
|
||||||
};
|
};
|
||||||
std::function<double(const double&)> op = Ring::id;
|
std::function<double(const double&)> op = Ring::id;
|
||||||
this->root_ = this->template convert(other.root_, op, map_function);
|
this->root_ = this->template convertFrom(other.root_, L_of_M, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** sum */
|
/** sum */
|
||||||
|
|
|
||||||
|
|
@ -461,20 +461,21 @@ namespace gtsam {
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename X>
|
template <typename X>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||||
std::function<Y(const X&)> op) {
|
std::function<Y(const X&)> Y_of_X) {
|
||||||
auto map = [](const L& label) { return label; };
|
auto L_of_L = [](const L& label) { return label; };
|
||||||
root_ = other.template convert<L, X>(op, map);
|
root_ = convertFrom<L, X>(Y_of_X, L_of_L);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
||||||
const std::map<M, L>& map, std::function<Y(const X&)> op) {
|
const std::map<M, L>& map,
|
||||||
std::function<L(const M&)> map_function = [&map](const M& label) -> L {
|
std::function<Y(const X&)> Y_of_X) {
|
||||||
|
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
|
||||||
return map.at(label);
|
return map.at(label);
|
||||||
};
|
};
|
||||||
root_ = other.template convert<M, X>(op, map_function);
|
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
|
|
@ -589,9 +590,10 @@ namespace gtsam {
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convert(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
const typename DecisionTree<M, X>::NodePtr& f,
|
const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<Y(const X&)> op, std::function<L(const M&)> map) const {
|
std::function<L(const M&)> L_of_M,
|
||||||
|
std::function<Y(const X&)> Y_of_X) const {
|
||||||
typedef DecisionTree<M, X> MX;
|
typedef DecisionTree<M, X> MX;
|
||||||
typedef typename MX::Leaf MXLeaf;
|
typedef typename MX::Leaf MXLeaf;
|
||||||
typedef typename MX::Choice MXChoice;
|
typedef typename MX::Choice MXChoice;
|
||||||
|
|
@ -601,7 +603,7 @@ namespace gtsam {
|
||||||
// ugliness below because apparently we can't have templated virtual functions
|
// ugliness below because apparently we can't have templated virtual functions
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf
|
// If leaf, apply unary conversion "op" and create a unique leaf
|
||||||
auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
|
auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
|
||||||
if (leaf) return NodePtr(new Leaf(op(leaf->constant())));
|
if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
|
|
||||||
// Check if Choice
|
// Check if Choice
|
||||||
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
|
||||||
|
|
@ -610,12 +612,12 @@ namespace gtsam {
|
||||||
|
|
||||||
// get new label
|
// get new label
|
||||||
const M oldLabel = choice->label();
|
const M oldLabel = choice->label();
|
||||||
const L newLabel = map(oldLabel);
|
const L newLabel = L_of_M(oldLabel);
|
||||||
|
|
||||||
// put together via Shannon expansion otherwise not sorted.
|
// put together via Shannon expansion otherwise not sorted.
|
||||||
std::vector<LY> functions;
|
std::vector<LY> functions;
|
||||||
for(const MXNodePtr& branch: choice->branches()) {
|
for(const MXNodePtr& branch: choice->branches()) {
|
||||||
LY converted(convert<M, X>(branch, op, map));
|
LY converted(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||||
functions += converted;
|
functions += converted;
|
||||||
}
|
}
|
||||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||||
|
|
|
||||||
|
|
@ -113,27 +113,20 @@ namespace gtsam {
|
||||||
/** A function is a shared pointer to the root of a DT */
|
/** A function is a shared pointer to the root of a DT */
|
||||||
typedef typename Node::Ptr NodePtr;
|
typedef typename Node::Ptr NodePtr;
|
||||||
|
|
||||||
protected:
|
/// a DecisionTree just contains the root. TODO(dellaert): make protected.
|
||||||
|
|
||||||
/* a DecisionTree just contains the root */
|
|
||||||
NodePtr root_;
|
NodePtr root_;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
|
||||||
/** Internal recursive function to create from keys, cardinalities, and Y values */
|
/** Internal recursive function to create from keys, cardinalities, and Y values */
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||||
|
|
||||||
/// Convert to a different type, will not convert label if map empty.
|
/// Convert from a DecisionTree<M, X>.
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
NodePtr convert(const typename DecisionTree<M, X>::NodePtr& f,
|
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<Y(const X&)> op,
|
std::function<L(const M&)> L_of_M,
|
||||||
std::function<L(const M&)> map) const;
|
std::function<Y(const X&)> Y_of_X) const;
|
||||||
|
|
||||||
/// Convert to a different type, will not convert label if map empty.
|
|
||||||
template <typename M, typename X>
|
|
||||||
NodePtr convert(std::function<Y(const X&)> op,
|
|
||||||
std::function<L(const M&)> map) const {
|
|
||||||
return convert(root_, op, map);
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
|
@ -169,12 +162,12 @@ namespace gtsam {
|
||||||
/** Convert from a different type. */
|
/** Convert from a different type. */
|
||||||
template <typename X>
|
template <typename X>
|
||||||
DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree(const DecisionTree<L, X>& other,
|
||||||
std::function<Y(const X&)> op);
|
std::function<Y(const X&)> Y_of_X);
|
||||||
|
|
||||||
/** Convert from a different type, also transate labels via map. */
|
/** Convert from a different type, also transate labels via map. */
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
DecisionTree(const DecisionTree<M, X>& other,
|
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
|
||||||
const std::map<M, L>& map, std::function<Y(const X&)> op);
|
std::function<Y(const X&)> Y_of_X);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
|
|
||||||
|
|
@ -84,8 +84,11 @@ GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
|
|
||||||
struct DT : public DecisionTree<string, int> {
|
struct DT : public DecisionTree<string, int> {
|
||||||
|
using Base = DecisionTree<string, int>;
|
||||||
using DecisionTree::DecisionTree;
|
using DecisionTree::DecisionTree;
|
||||||
DT(const DecisionTree<string, int>& dt) : root_(dt.root_) {}
|
DT() = default;
|
||||||
|
|
||||||
|
DT(const Base& dt) : Base(dt) {}
|
||||||
|
|
||||||
/// print to stdout
|
/// print to stdout
|
||||||
void print(const std::string& s = "") const {
|
void print(const std::string& s = "") const {
|
||||||
|
|
@ -93,15 +96,13 @@ struct DT : public DecisionTree<string, int> {
|
||||||
auto valueFormatter = [](const int& v) {
|
auto valueFormatter = [](const int& v) {
|
||||||
return (boost::format("%d") % v).str();
|
return (boost::format("%d") % v).str();
|
||||||
};
|
};
|
||||||
DecisionTree<string, int>::print("", keyFormatter, valueFormatter);
|
Base::print("", keyFormatter, valueFormatter);
|
||||||
|
}
|
||||||
|
/// Equality method customized to int node type
|
||||||
|
bool equals(const Base& other, double tol = 1e-9) const {
|
||||||
|
auto compare = [](const int& v, const int& w) { return v == w; };
|
||||||
|
return Base::equals(other, compare);
|
||||||
}
|
}
|
||||||
// /// Equality method customized to int node type
|
|
||||||
// bool equals(const CrazyDecisionTree& other, double tol = 1e-9) const {
|
|
||||||
// auto compare = [tol](const int& v, const int& w) {
|
|
||||||
// return v.a == w.a && std::abs(v.b - w.b) < tol;
|
|
||||||
// };
|
|
||||||
// return DecisionTree<string, int>::equals(other, compare);
|
|
||||||
// }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
@ -131,111 +132,111 @@ struct Ring {
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
// TEST(DT, example)
|
TEST(DT, example)
|
||||||
// {
|
{
|
||||||
// // Create labels
|
// Create labels
|
||||||
// string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
// // create a value
|
// create a value
|
||||||
// Assignment<string> x00, x01, x10, x11;
|
Assignment<string> x00, x01, x10, x11;
|
||||||
// x00[A] = 0, x00[B] = 0;
|
x00[A] = 0, x00[B] = 0;
|
||||||
// x01[A] = 0, x01[B] = 1;
|
x01[A] = 0, x01[B] = 1;
|
||||||
// x10[A] = 1, x10[B] = 0;
|
x10[A] = 1, x10[B] = 0;
|
||||||
// x11[A] = 1, x11[B] = 1;
|
x11[A] = 1, x11[B] = 1;
|
||||||
|
|
||||||
// // empty
|
// empty
|
||||||
// DT empty;
|
DT empty;
|
||||||
|
|
||||||
// // A
|
// A
|
||||||
// DT a(A, 0, 5);
|
DT a(A, 0, 5);
|
||||||
// LONGS_EQUAL(0,a(x00))
|
LONGS_EQUAL(0,a(x00))
|
||||||
// LONGS_EQUAL(5,a(x10))
|
LONGS_EQUAL(5,a(x10))
|
||||||
// DOT(a);
|
DOT(a);
|
||||||
|
|
||||||
// // pruned
|
// pruned
|
||||||
// DT p(A, 2, 2);
|
DT p(A, 2, 2);
|
||||||
// LONGS_EQUAL(2,p(x00))
|
LONGS_EQUAL(2,p(x00))
|
||||||
// LONGS_EQUAL(2,p(x10))
|
LONGS_EQUAL(2,p(x10))
|
||||||
// DOT(p);
|
DOT(p);
|
||||||
|
|
||||||
// // \neg B
|
// \neg B
|
||||||
// DT notb(B, 5, 0);
|
DT notb(B, 5, 0);
|
||||||
// LONGS_EQUAL(5,notb(x00))
|
LONGS_EQUAL(5,notb(x00))
|
||||||
// LONGS_EQUAL(5,notb(x10))
|
LONGS_EQUAL(5,notb(x10))
|
||||||
// DOT(notb);
|
DOT(notb);
|
||||||
|
|
||||||
// // Check supplying empty trees yields an exception
|
// Check supplying empty trees yields an exception
|
||||||
// CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
|
CHECK_EXCEPTION(apply(empty, &Ring::id), std::runtime_error);
|
||||||
// CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
|
CHECK_EXCEPTION(apply(empty, a, &Ring::mul), std::runtime_error);
|
||||||
// CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
|
CHECK_EXCEPTION(apply(a, empty, &Ring::mul), std::runtime_error);
|
||||||
|
|
||||||
// // apply, two nodes, in natural order
|
// apply, two nodes, in natural order
|
||||||
// DT anotb = apply(a, notb, &Ring::mul);
|
DT anotb = apply(a, notb, &Ring::mul);
|
||||||
// LONGS_EQUAL(0,anotb(x00))
|
LONGS_EQUAL(0,anotb(x00))
|
||||||
// LONGS_EQUAL(0,anotb(x01))
|
LONGS_EQUAL(0,anotb(x01))
|
||||||
// LONGS_EQUAL(25,anotb(x10))
|
LONGS_EQUAL(25,anotb(x10))
|
||||||
// LONGS_EQUAL(0,anotb(x11))
|
LONGS_EQUAL(0,anotb(x11))
|
||||||
// DOT(anotb);
|
DOT(anotb);
|
||||||
|
|
||||||
// // check pruning
|
// check pruning
|
||||||
// DT pnotb = apply(p, notb, &Ring::mul);
|
DT pnotb = apply(p, notb, &Ring::mul);
|
||||||
// LONGS_EQUAL(10,pnotb(x00))
|
LONGS_EQUAL(10,pnotb(x00))
|
||||||
// LONGS_EQUAL( 0,pnotb(x01))
|
LONGS_EQUAL( 0,pnotb(x01))
|
||||||
// LONGS_EQUAL(10,pnotb(x10))
|
LONGS_EQUAL(10,pnotb(x10))
|
||||||
// LONGS_EQUAL( 0,pnotb(x11))
|
LONGS_EQUAL( 0,pnotb(x11))
|
||||||
// DOT(pnotb);
|
DOT(pnotb);
|
||||||
|
|
||||||
// // check pruning
|
// check pruning
|
||||||
// DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
|
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
|
||||||
// LONGS_EQUAL(0,zeros(x00))
|
LONGS_EQUAL(0,zeros(x00))
|
||||||
// LONGS_EQUAL(0,zeros(x01))
|
LONGS_EQUAL(0,zeros(x01))
|
||||||
// LONGS_EQUAL(0,zeros(x10))
|
LONGS_EQUAL(0,zeros(x10))
|
||||||
// LONGS_EQUAL(0,zeros(x11))
|
LONGS_EQUAL(0,zeros(x11))
|
||||||
// DOT(zeros);
|
DOT(zeros);
|
||||||
|
|
||||||
// // apply, two nodes, in switched order
|
// apply, two nodes, in switched order
|
||||||
// DT notba = apply(a, notb, &Ring::mul);
|
DT notba = apply(a, notb, &Ring::mul);
|
||||||
// LONGS_EQUAL(0,notba(x00))
|
LONGS_EQUAL(0,notba(x00))
|
||||||
// LONGS_EQUAL(0,notba(x01))
|
LONGS_EQUAL(0,notba(x01))
|
||||||
// LONGS_EQUAL(25,notba(x10))
|
LONGS_EQUAL(25,notba(x10))
|
||||||
// LONGS_EQUAL(0,notba(x11))
|
LONGS_EQUAL(0,notba(x11))
|
||||||
// DOT(notba);
|
DOT(notba);
|
||||||
|
|
||||||
// // Test choose 0
|
// Test choose 0
|
||||||
// DT actual0 = notba.choose(A, 0);
|
DT actual0 = notba.choose(A, 0);
|
||||||
// EXPECT(assert_equal(DT(0.0), actual0));
|
EXPECT(assert_equal(DT(0.0), actual0));
|
||||||
// DOT(actual0);
|
DOT(actual0);
|
||||||
|
|
||||||
// // Test choose 1
|
// Test choose 1
|
||||||
// DT actual1 = notba.choose(A, 1);
|
DT actual1 = notba.choose(A, 1);
|
||||||
// EXPECT(assert_equal(DT(B, 25, 0), actual1));
|
EXPECT(assert_equal(DT(B, 25, 0), actual1));
|
||||||
// DOT(actual1);
|
DOT(actual1);
|
||||||
|
|
||||||
// // apply, two nodes at same level
|
// apply, two nodes at same level
|
||||||
// DT a_and_a = apply(a, a, &Ring::mul);
|
DT a_and_a = apply(a, a, &Ring::mul);
|
||||||
// LONGS_EQUAL(0,a_and_a(x00))
|
LONGS_EQUAL(0,a_and_a(x00))
|
||||||
// LONGS_EQUAL(0,a_and_a(x01))
|
LONGS_EQUAL(0,a_and_a(x01))
|
||||||
// LONGS_EQUAL(25,a_and_a(x10))
|
LONGS_EQUAL(25,a_and_a(x10))
|
||||||
// LONGS_EQUAL(25,a_and_a(x11))
|
LONGS_EQUAL(25,a_and_a(x11))
|
||||||
// DOT(a_and_a);
|
DOT(a_and_a);
|
||||||
|
|
||||||
// // create a function on C
|
// create a function on C
|
||||||
// DT c(C, 0, 5);
|
DT c(C, 0, 5);
|
||||||
|
|
||||||
// // and a model assigning stuff to C
|
// and a model assigning stuff to C
|
||||||
// Assignment<string> x101;
|
Assignment<string> x101;
|
||||||
// x101[A] = 1, x101[B] = 0, x101[C] = 1;
|
x101[A] = 1, x101[B] = 0, x101[C] = 1;
|
||||||
|
|
||||||
// // mul notba with C
|
// mul notba with C
|
||||||
// DT notbac = apply(notba, c, &Ring::mul);
|
DT notbac = apply(notba, c, &Ring::mul);
|
||||||
// LONGS_EQUAL(125,notbac(x101))
|
LONGS_EQUAL(125,notbac(x101))
|
||||||
// DOT(notbac);
|
DOT(notbac);
|
||||||
|
|
||||||
// // mul now in different order
|
// mul now in different order
|
||||||
// DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
|
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
|
||||||
// LONGS_EQUAL(125,acnotb(x101))
|
LONGS_EQUAL(125,acnotb(x101))
|
||||||
// DOT(acnotb);
|
DOT(acnotb);
|
||||||
// }
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// test Conversion
|
// test Conversion
|
||||||
|
|
@ -243,9 +244,6 @@ enum Label {
|
||||||
U, V, X, Y, Z
|
U, V, X, Y, Z
|
||||||
};
|
};
|
||||||
typedef DecisionTree<Label, bool> BDT;
|
typedef DecisionTree<Label, bool> BDT;
|
||||||
bool convert(const int& y) {
|
|
||||||
return y != 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(DT, conversion)
|
TEST(DT, conversion)
|
||||||
{
|
{
|
||||||
|
|
@ -259,8 +257,10 @@ TEST(DT, conversion)
|
||||||
map<string, Label> ordering;
|
map<string, Label> ordering;
|
||||||
ordering[A] = X;
|
ordering[A] = X;
|
||||||
ordering[B] = Y;
|
ordering[B] = Y;
|
||||||
std::function<bool(const int&)> op = convert;
|
std::function<bool(const int&)> bool_of_int = [](const int& y) {
|
||||||
BDT f2(f1, ordering, op);
|
return y != 0;
|
||||||
|
};
|
||||||
|
BDT f2(f1, ordering, bool_of_int);
|
||||||
// f1.print("f1");
|
// f1.print("f1");
|
||||||
// f2.print("f2");
|
// f2.print("f2");
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue