Straight-up depth-first fold method
# Conflicts: # gtsam/discrete/tests/testDecisionTree.cpprelease/4.3a0
parent
53b4053c20
commit
15850333b4
|
@ -630,6 +630,48 @@ namespace gtsam {
|
||||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************/
|
||||||
|
template <typename L, typename Y, typename X>
|
||||||
|
struct Fold {
|
||||||
|
std::function<X(const Y&, X)> f;
|
||||||
|
|
||||||
|
/// Construct from folding function
|
||||||
|
Fold(std::function<X(const Y&, X)> f) : f(f) {}
|
||||||
|
|
||||||
|
using NodePtr = typename DecisionTree<L, Y>::NodePtr;
|
||||||
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Do a depth-first fold on the tree rooted at node.
|
||||||
|
*
|
||||||
|
* @param node root of a (sub-) tree, or a leaf.
|
||||||
|
* @param x0 Initial accumulator value.
|
||||||
|
* @return X Final accumulator value.
|
||||||
|
*/
|
||||||
|
X fold(const NodePtr& node, X x0) const {
|
||||||
|
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node)) {
|
||||||
|
return f(leaf->constant(), x0);
|
||||||
|
} else if (auto choice =
|
||||||
|
boost::dynamic_pointer_cast<const Choice>(node)) {
|
||||||
|
for (auto&& branch : choice->branches()) x0 = fold(branch, x0);
|
||||||
|
return x0;
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument("Fold: Invalid NodePtr");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// alias for fold:
|
||||||
|
X operator()(const NodePtr& node, X x0) const { return fold(node, x0); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Func, typename X>
|
||||||
|
X DecisionTree<L, Y>::fold(Func f, X x0) const {
|
||||||
|
Fold<L, Y, X> fold(f);
|
||||||
|
return fold(root_, x0);
|
||||||
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||||
|
|
|
@ -229,6 +229,22 @@ namespace gtsam {
|
||||||
/** evaluate */
|
/** evaluate */
|
||||||
const Y& operator()(const Assignment<L>& x) const;
|
const Y& operator()(const Assignment<L>& x) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Fold a binary function over the tree, returning accumulator.
|
||||||
|
*
|
||||||
|
* @tparam X type for accumulator.
|
||||||
|
* @param f binary function: Y * X -> X returning an updated accumulator.
|
||||||
|
* @param x0 initial value for accumulator.
|
||||||
|
* @return X final value for accumulator.
|
||||||
|
*
|
||||||
|
* @note X is always passed by value.
|
||||||
|
*/
|
||||||
|
template <typename Func, typename X>
|
||||||
|
X fold(Func f, X x0) const;
|
||||||
|
|
||||||
|
/** Retrieve all labels. */
|
||||||
|
std::vector<L> labels() const { return std::vector<L>(); }
|
||||||
|
|
||||||
/** apply Unary operation "op" to f */
|
/** apply Unary operation "op" to f */
|
||||||
DecisionTree apply(const Unary& op) const;
|
DecisionTree apply(const Unary& op) const;
|
||||||
|
|
||||||
|
|
|
@ -123,8 +123,7 @@ struct Ring {
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
TEST(DT, example)
|
TEST(DecisionTree, example) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
|
@ -236,8 +235,7 @@ std::function<bool(const int&)> bool_of_int = [](const int& y) {
|
||||||
};
|
};
|
||||||
typedef DecisionTree<string, bool> StringBoolTree;
|
typedef DecisionTree<string, bool> StringBoolTree;
|
||||||
|
|
||||||
TEST(DT, ConvertValuesOnly)
|
TEST(DecisionTree, ConvertValuesOnly) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
@ -260,8 +258,7 @@ enum Label {
|
||||||
};
|
};
|
||||||
typedef DecisionTree<Label, bool> LabelBoolTree;
|
typedef DecisionTree<Label, bool> LabelBoolTree;
|
||||||
|
|
||||||
TEST(DT, ConvertBoth)
|
TEST(DecisionTree, ConvertBoth) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
@ -288,8 +285,7 @@ TEST(DT, ConvertBoth)
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// test Compose expansion
|
// test Compose expansion
|
||||||
TEST(DT, Compose)
|
TEST(DecisionTree, Compose) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
|
@ -314,6 +310,49 @@ TEST(DT, Compose)
|
||||||
DOT(f5);
|
DOT(f5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
// Check we can create a decision tree of containers.
|
||||||
|
TEST(DecisionTree, Containers) {
|
||||||
|
using Container = std::vector<double>;
|
||||||
|
using StringContainerTree = DecisionTree<string, Container>;
|
||||||
|
|
||||||
|
// Check default constructor
|
||||||
|
StringContainerTree tree;
|
||||||
|
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
|
||||||
|
// Check conversion
|
||||||
|
std::function<Container(const int& i)> container_of_int = [](const int& i) {
|
||||||
|
Container c;
|
||||||
|
c.emplace_back(i);
|
||||||
|
return c;
|
||||||
|
};
|
||||||
|
StringContainerTree converted(stringIntTree, container_of_int);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
// Test fold.
|
||||||
|
TEST(DecisionTree, fold) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
auto add = [](const int& y, double x) { return y + x; };
|
||||||
|
double sum = tree.fold(add, 0.0);
|
||||||
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
// Test retrieving all labels.
|
||||||
|
TEST_DISABLED(DecisionTree, labels) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
auto labels = tree.labels();
|
||||||
|
EXPECT_LONGS_EQUAL(2, labels.size());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue