Improved documentation and tests
parent
dafa0076ec
commit
febeacd686
|
@ -28,9 +28,9 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Algebraic Decision Trees fix the range to double
|
* An algebraic decision tree fixes the range of a DecisionTree to double.
|
||||||
* Just has some nice constructors and some syntactic sugar
|
* Just has some nice constructors and some syntactic sugar.
|
||||||
* TODO: consider eliminating this class altogether?
|
* TODO(dellaert): consider eliminating this class altogether?
|
||||||
*
|
*
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
|
@ -80,20 +80,62 @@ namespace gtsam {
|
||||||
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
||||||
: Base(label, y1, y2) {}
|
: Base(label, y1, y2) {}
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/**
|
||||||
|
* @brief Create a new leaf function splitting on a variable
|
||||||
|
*
|
||||||
|
* @param labelC: The label with cardinality 2
|
||||||
|
* @param y1: The value for the first key
|
||||||
|
* @param y2: The value for the second key
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* @code{.cpp}
|
||||||
|
* std::pair<string, size_t> A {"a", 2};
|
||||||
|
* AlgebraicDecisionTree<string> a(A, 0.6, 0.4);
|
||||||
|
* @endcode
|
||||||
|
*/
|
||||||
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
||||||
double y2)
|
double y2)
|
||||||
: Base(labelC, y1, y2) {}
|
: Base(labelC, y1, y2) {}
|
||||||
|
|
||||||
/** Create from keys and vector table */
|
/**
|
||||||
|
* @brief Create from keys with cardinalities and a vector table
|
||||||
|
*
|
||||||
|
* @param labelCs: The keys, with cardinalities, given as pairs
|
||||||
|
* @param ys: The vector table
|
||||||
|
*
|
||||||
|
* Example with three keys, A, B, and C, with cardinalities 2, 3, and 2,
|
||||||
|
* respectively, and a vector table of size 12:
|
||||||
|
* @code{.cpp}
|
||||||
|
* DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||||
|
* const vector<double> cpt{
|
||||||
|
* 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||||
|
* 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
|
||||||
|
* AlgebraicDecisionTree<Key> expected(A & B & C, cpt);
|
||||||
|
* @endcode
|
||||||
|
* The table is given in the following order:
|
||||||
|
* A=0, B=0, C=0
|
||||||
|
* A=0, B=0, C=1
|
||||||
|
* ...
|
||||||
|
* A=1, B=1, C=1
|
||||||
|
* Hence, the first line in the table is for A==0, and the second for A==1.
|
||||||
|
* In each line, the first two entries are for B==0, the next two for B==1,
|
||||||
|
* and the last two for B==2. Each pair is for a C value of 0 and 1.
|
||||||
|
*/
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Base::LabelC>& labelCs,
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
const std::vector<double>& ys) {
|
const std::vector<double>& ys) {
|
||||||
this->root_ =
|
this->root_ =
|
||||||
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create from keys and string table */
|
/**
|
||||||
|
* @brief Create from keys and string table
|
||||||
|
*
|
||||||
|
* @param labelCs: The keys, with cardinalities, given as pairs
|
||||||
|
* @param table: The string table, given as a string of doubles.
|
||||||
|
*
|
||||||
|
* @note Table needs to be in same order as the vector table in the other constructor.
|
||||||
|
*/
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Base::LabelC>& labelCs,
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
const std::string& table) {
|
const std::string& table) {
|
||||||
|
@ -108,7 +150,13 @@ namespace gtsam {
|
||||||
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new function splitting on a variable */
|
/**
|
||||||
|
* @brief Create a range of decision trees, splitting on a single variable.
|
||||||
|
*
|
||||||
|
* @param begin: Iterator to beginning of a range of decision trees
|
||||||
|
* @param end: Iterator to end of a range of decision trees
|
||||||
|
* @param label: The label to split on
|
||||||
|
*/
|
||||||
template <typename Iterator>
|
template <typename Iterator>
|
||||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
||||||
: Base(nullptr) {
|
: Base(nullptr) {
|
||||||
|
|
|
@ -622,7 +622,7 @@ namespace gtsam {
|
||||||
// B=1
|
// B=1
|
||||||
// A=0: 3
|
// A=0: 3
|
||||||
// A=1: 4
|
// A=1: 4
|
||||||
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce
|
// Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
|
||||||
// exactly the same tree as above: the highest label is always the root.
|
// exactly the same tree as above: the highest label is always the root.
|
||||||
// However, it will be *way* faster if labels are given highest to lowest.
|
// However, it will be *way* faster if labels are given highest to lowest.
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
|
|
|
@ -37,9 +37,23 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Decision Tree
|
* @brief a decision tree is a function from assignments to values.
|
||||||
* L = label for variables
|
* @tparam L label for variables
|
||||||
* Y = function range (any algebra), e.g., bool, int, double
|
* @tparam Y function range (any algebra), e.g., bool, int, double
|
||||||
|
*
|
||||||
|
* After creating a decision tree on some variables, the tree can be evaluated
|
||||||
|
* on an assignment to those variables. Example:
|
||||||
|
*
|
||||||
|
* @code{.cpp}
|
||||||
|
* // Create a decision stump one one variable 'a' with values 10 and 20.
|
||||||
|
* DecisionTree<char, int> tree('a', 10, 20);
|
||||||
|
*
|
||||||
|
* // Evaluate the tree on an assignment to the variable.
|
||||||
|
* int value0 = tree({{'a', 0}}); // value0 = 10
|
||||||
|
* int value1 = tree({{'a', 1}}); // value1 = 20
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* More examples can be found in testDecisionTree.cpp
|
||||||
*
|
*
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
|
@ -132,7 +146,8 @@ namespace gtsam {
|
||||||
NodePtr root_;
|
NodePtr root_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/** Internal recursive function to create from keys, cardinalities,
|
/**
|
||||||
|
* Internal recursive function to create from keys, cardinalities,
|
||||||
* and Y values
|
* and Y values
|
||||||
*/
|
*/
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
|
@ -163,7 +178,13 @@ namespace gtsam {
|
||||||
/** Create a constant */
|
/** Create a constant */
|
||||||
explicit DecisionTree(const Y& y);
|
explicit DecisionTree(const Y& y);
|
||||||
|
|
||||||
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
|
/**
|
||||||
|
* @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
|
||||||
|
*
|
||||||
|
* @param label The variable to split on.
|
||||||
|
* @param y1 The value for the first assignment.
|
||||||
|
* @param y2 The value for the second assignment.
|
||||||
|
*/
|
||||||
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
||||||
|
|
||||||
/** Allow Label+Cardinality for convenience */
|
/** Allow Label+Cardinality for convenience */
|
||||||
|
|
|
@ -71,6 +71,19 @@ struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test char labels and int range
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
// Create a decision stump one one variable 'a' with values 10 and 20.
|
||||||
|
TEST(DecisionTree, constructor) {
|
||||||
|
DecisionTree<char, int> tree('a', 10, 20);
|
||||||
|
|
||||||
|
// Evaluate the tree on an assignment to the variable.
|
||||||
|
EXPECT_LONGS_EQUAL(10, tree({{'a', 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(20, tree({{'a', 1}}));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test string labels and int range
|
// Test string labels and int range
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
@ -114,18 +127,47 @@ struct Ring {
|
||||||
static inline int mul(const int& a, const int& b) { return a * b; }
|
static inline int mul(const int& a, const int& b) { return a * b; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Check that creating decision trees respects key order.
|
||||||
|
TEST(DecisionTree, constructor_order) {
|
||||||
|
// Create labels
|
||||||
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
const std::vector<int> ys1 = {1, 2, 3, 4};
|
||||||
|
DT tree1({{B, 2}, {A, 2}}, ys1); // faster version, as B is "higher" than A!
|
||||||
|
|
||||||
|
const std::vector<int> ys2 = {1, 3, 2, 4};
|
||||||
|
DT tree2({{A, 2}, {B, 2}}, ys2); // slower version !
|
||||||
|
|
||||||
|
// Both trees will be the same, tree is order from high to low labels.
|
||||||
|
// Choice(B)
|
||||||
|
// 0 Choice(A)
|
||||||
|
// 0 0 Leaf 1
|
||||||
|
// 0 1 Leaf 2
|
||||||
|
// 1 Choice(A)
|
||||||
|
// 1 0 Leaf 3
|
||||||
|
// 1 1 Leaf 4
|
||||||
|
|
||||||
|
EXPECT(tree2.equals(tree1));
|
||||||
|
|
||||||
|
// Check the values are as expected by calling the () operator:
|
||||||
|
EXPECT_LONGS_EQUAL(1, tree1({{A, 0}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(3, tree1({{A, 0}, {B, 1}}));
|
||||||
|
EXPECT_LONGS_EQUAL(2, tree1({{A, 1}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(4, tree1({{A, 1}, {B, 1}}));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
TEST(DecisionTree, example) {
|
TEST(DecisionTree, example) {
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
// create a value
|
// Create assignments using brace initialization:
|
||||||
Assignment<string> x00, x01, x10, x11;
|
Assignment<string> x00{{A, 0}, {B, 0}};
|
||||||
x00[A] = 0, x00[B] = 0;
|
Assignment<string> x01{{A, 0}, {B, 1}};
|
||||||
x01[A] = 0, x01[B] = 1;
|
Assignment<string> x10{{A, 1}, {B, 0}};
|
||||||
x10[A] = 1, x10[B] = 0;
|
Assignment<string> x11{{A, 1}, {B, 1}};
|
||||||
x11[A] = 1, x11[B] = 1;
|
|
||||||
|
|
||||||
// empty
|
// empty
|
||||||
DT empty;
|
DT empty;
|
||||||
|
@ -237,8 +279,7 @@ TEST(DecisionTree, ConvertValuesOnly) {
|
||||||
StringBoolTree f2(f1, bool_of_int);
|
StringBoolTree f2(f1, bool_of_int);
|
||||||
|
|
||||||
// Check a value
|
// Check a value
|
||||||
Assignment<string> x00;
|
Assignment<string> x00 {{A, 0}, {B, 0}};
|
||||||
x00["A"] = 0, x00["B"] = 0;
|
|
||||||
EXPECT(!f2(x00));
|
EXPECT(!f2(x00));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -262,10 +303,11 @@ TEST(DecisionTree, ConvertBoth) {
|
||||||
|
|
||||||
// Check some values
|
// Check some values
|
||||||
Assignment<Label> x00, x01, x10, x11;
|
Assignment<Label> x00, x01, x10, x11;
|
||||||
x00[X] = 0, x00[Y] = 0;
|
x00 = {{X, 0}, {Y, 0}};
|
||||||
x01[X] = 0, x01[Y] = 1;
|
x01 = {{X, 0}, {Y, 1}};
|
||||||
x10[X] = 1, x10[Y] = 0;
|
x10 = {{X, 1}, {Y, 0}};
|
||||||
x11[X] = 1, x11[Y] = 1;
|
x11 = {{X, 1}, {Y, 1}};
|
||||||
|
|
||||||
EXPECT(!f2(x00));
|
EXPECT(!f2(x00));
|
||||||
EXPECT(!f2(x01));
|
EXPECT(!f2(x01));
|
||||||
EXPECT(f2(x10));
|
EXPECT(f2(x10));
|
||||||
|
|
Loading…
Reference in New Issue