Improved documentation and tests

release/4.3a0
Frank Dellaert 2023-06-04 15:40:02 +01:00
parent dafa0076ec
commit febeacd686
4 changed files with 137 additions and 26 deletions

View File

@ -28,9 +28,9 @@
namespace gtsam { namespace gtsam {
/** /**
* Algebraic Decision Trees fix the range to double * An algebraic decision tree fixes the range of a DecisionTree to double.
* Just has some nice constructors and some syntactic sugar * Just has some nice constructors and some syntactic sugar.
* TODO: consider eliminating this class altogether? * TODO(dellaert): consider eliminating this class altogether?
* *
* @ingroup discrete * @ingroup discrete
*/ */
@ -80,20 +80,62 @@ namespace gtsam {
AlgebraicDecisionTree(const L& label, double y1, double y2) AlgebraicDecisionTree(const L& label, double y1, double y2)
: Base(label, y1, y2) {} : Base(label, y1, y2) {}
/** Create a new leaf function splitting on a variable */ /**
* @brief Create a new leaf function splitting on a variable
*
* @param labelC: The label with cardinality 2
* @param y1: The value for the first key
* @param y2: The value for the second key
*
* Example:
* @code{.cpp}
* std::pair<string, size_t> A {"a", 2};
* AlgebraicDecisionTree<string> a(A, 0.6, 0.4);
* @endcode
*/
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
double y2) double y2)
: Base(labelC, y1, y2) {} : Base(labelC, y1, y2) {}
/** Create from keys and vector table */ /**
* @brief Create from keys with cardinalities and a vector table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param ys: The vector table
*
* Example with three keys, A, B, and C, with cardinalities 2, 3, and 2,
* respectively, and a vector table of size 12:
* @code{.cpp}
* DiscreteKey A(0, 2), B(1, 3), C(2, 2);
* const vector<double> cpt{
* 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
* 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
* AlgebraicDecisionTree<Key> expected(A & B & C, cpt);
* @endcode
* The table is given in the following order:
* A=0, B=0, C=0
* A=0, B=0, C=1
* ...
* A=1, B=2, C=1
* Hence, the first line in the table is for A==0, and the second for A==1.
* In each line, the first two entries are for B==0, the next two for B==1,
* and the last two for B==2. Each pair is for a C value of 0 and 1.
*/
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, (const std::vector<typename Base::LabelC>& labelCs,
const std::vector<double>& ys) { const std::vector<double>& ys) {
this->root_ = this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create from keys and string table */ /**
* @brief Create from keys and string table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param table: The string table, given as a string of doubles.
*
* @note Table needs to be in same order as the vector table in the other constructor.
*/
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, (const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) { const std::string& table) {
@ -108,7 +150,13 @@ namespace gtsam {
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create a new function splitting on a variable */ /**
* @brief Create a decision tree from a range of decision trees, splitting on a single variable.
*
* @param begin: Iterator to beginning of a range of decision trees
* @param end: Iterator to end of a range of decision trees
* @param label: The label to split on
*/
template <typename Iterator> template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
: Base(nullptr) { : Base(nullptr) {

View File

@ -622,7 +622,7 @@ namespace gtsam {
// B=1 // B=1
// A=0: 3 // A=0: 3
// A=1: 4 // A=1: 4
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce // Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
// exactly the same tree as above: the highest label is always the root. // exactly the same tree as above: the highest label is always the root.
// However, it will be *way* faster if labels are given highest to lowest. // However, it will be *way* faster if labels are given highest to lowest.
template<typename L, typename Y> template<typename L, typename Y>

View File

@ -37,9 +37,23 @@
namespace gtsam { namespace gtsam {
/** /**
* Decision Tree * @brief a decision tree is a function from assignments to values.
* L = label for variables * @tparam L label for variables
* Y = function range (any algebra), e.g., bool, int, double * @tparam Y function range (any algebra), e.g., bool, int, double
*
* After creating a decision tree on some variables, the tree can be evaluated
* on an assignment to those variables. Example:
*
* @code{.cpp}
* // Create a decision stump on one variable 'a' with values 10 and 20.
* DecisionTree<char, int> tree('a', 10, 20);
*
* // Evaluate the tree on an assignment to the variable.
* int value0 = tree({{'a', 0}}); // value0 = 10
* int value1 = tree({{'a', 1}}); // value1 = 20
* @endcode
*
* More examples can be found in testDecisionTree.cpp
* *
* @ingroup discrete * @ingroup discrete
*/ */
@ -132,7 +146,8 @@ namespace gtsam {
NodePtr root_; NodePtr root_;
protected: protected:
/** Internal recursive function to create from keys, cardinalities, /**
* Internal recursive function to create from keys, cardinalities,
* and Y values * and Y values
*/ */
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
@ -163,7 +178,13 @@ namespace gtsam {
/** Create a constant */ /** Create a constant */
explicit DecisionTree(const Y& y); explicit DecisionTree(const Y& y);
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` /**
* @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
*
* @param label The variable to split on.
* @param y1 The value for the first assignment.
* @param y2 The value for the second assignment.
*/
DecisionTree(const L& label, const Y& y1, const Y& y2); DecisionTree(const L& label, const Y& y1, const Y& y2);
/** Allow Label+Cardinality for convenience */ /** Allow Label+Cardinality for convenience */

View File

@ -71,6 +71,19 @@ struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ************************************************************************** */
// Test char labels and int range
/* ************************************************************************** */
// Create a decision stump on one variable 'a' with values 10 and 20.
TEST(DecisionTree, constructor) {
  // A stump over the single char label 'a' with leaf values 10 and 20.
  DecisionTree<char, int> stump('a', 10, 20);

  // Look up each value of 'a' and confirm the matching leaf comes back.
  const int atZero = stump({{'a', 0}});
  const int atOne = stump({{'a', 1}});
  EXPECT_LONGS_EQUAL(10, atZero);
  EXPECT_LONGS_EQUAL(20, atOne);
}
/* ************************************************************************** */ /* ************************************************************************** */
// Test string labels and int range // Test string labels and int range
/* ************************************************************************** */ /* ************************************************************************** */
@ -114,18 +127,47 @@ struct Ring {
static inline int mul(const int& a, const int& b) { return a * b; } static inline int mul(const int& a, const int& b) { return a * b; }
}; };
/* ************************************************************************** */
// Check that creating decision trees respects key order.
TEST(DecisionTree, constructor_order) {
  // Two string labels; "B" counts as the "higher" label of the pair.
  string A("A"), B("B");

  // The same function, written out once per key order.
  const std::vector<int> valuesHighFirst = {1, 2, 3, 4};
  const std::vector<int> valuesLowFirst = {1, 3, 2, 4};

  // Keys given high to low: the faster construction path.
  DT highFirst({{B, 2}, {A, 2}}, valuesHighFirst);
  // Keys given low to high: the slower construction path.
  DT lowFirst({{A, 2}, {B, 2}}, valuesLowFirst);

  // Both constructions must give the same tree, which is ordered from high
  // to low labels:
  // Choice(B)
  // 0 Choice(A)
  // 0 0 Leaf 1
  // 0 1 Leaf 2
  // 1 Choice(A)
  // 1 0 Leaf 3
  // 1 1 Leaf 4
  EXPECT(lowFirst.equals(highFirst));

  // Evaluate via operator() on every assignment and compare with the leaves.
  EXPECT_LONGS_EQUAL(1, highFirst({{A, 0}, {B, 0}}));
  EXPECT_LONGS_EQUAL(3, highFirst({{A, 0}, {B, 1}}));
  EXPECT_LONGS_EQUAL(2, highFirst({{A, 1}, {B, 0}}));
  EXPECT_LONGS_EQUAL(4, highFirst({{A, 1}, {B, 1}}));
}
/* ************************************************************************** */ /* ************************************************************************** */
// test DT // test DT
TEST(DecisionTree, example) { TEST(DecisionTree, example) {
// Create labels // Create labels
string A("A"), B("B"), C("C"); string A("A"), B("B"), C("C");
// create a value // Create assignments using brace initialization:
Assignment<string> x00, x01, x10, x11; Assignment<string> x00{{A, 0}, {B, 0}};
x00[A] = 0, x00[B] = 0; Assignment<string> x01{{A, 0}, {B, 1}};
x01[A] = 0, x01[B] = 1; Assignment<string> x10{{A, 1}, {B, 0}};
x10[A] = 1, x10[B] = 0; Assignment<string> x11{{A, 1}, {B, 1}};
x11[A] = 1, x11[B] = 1;
// empty // empty
DT empty; DT empty;
@ -237,8 +279,7 @@ TEST(DecisionTree, ConvertValuesOnly) {
StringBoolTree f2(f1, bool_of_int); StringBoolTree f2(f1, bool_of_int);
// Check a value // Check a value
Assignment<string> x00; Assignment<string> x00 {{A, 0}, {B, 0}};
x00["A"] = 0, x00["B"] = 0;
EXPECT(!f2(x00)); EXPECT(!f2(x00));
} }
@ -262,10 +303,11 @@ TEST(DecisionTree, ConvertBoth) {
// Check some values // Check some values
Assignment<Label> x00, x01, x10, x11; Assignment<Label> x00, x01, x10, x11;
x00[X] = 0, x00[Y] = 0; x00 = {{X, 0}, {Y, 0}};
x01[X] = 0, x01[Y] = 1; x01 = {{X, 0}, {Y, 1}};
x10[X] = 1, x10[Y] = 0; x10 = {{X, 1}, {Y, 0}};
x11[X] = 1, x11[Y] = 1; x11 = {{X, 1}, {Y, 1}};
EXPECT(!f2(x00)); EXPECT(!f2(x00));
EXPECT(!f2(x01)); EXPECT(!f2(x01));
EXPECT(f2(x10)); EXPECT(f2(x10));