diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 11ecbf183..5e17fa7d6 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -630,6 +630,48 @@ namespace gtsam { return LY::compose(functions.begin(), functions.end(), newLabel); } + /*********************************************************************************/ + template + struct Fold { + std::function f; + + /// Construct from folding function + Fold(std::function f) : f(f) {} + + using NodePtr = typename DecisionTree::NodePtr; + using Choice = typename DecisionTree::Choice; + using Leaf = typename DecisionTree::Leaf; + + /** + * @brief Do a depth-first fold on the tree rooted at node. + * + * @param node root of a (sub-) tree, or a leaf. + * @param x0 Initial accumulator value. + * @return X Final accumulator value. + */ + X fold(const NodePtr& node, X x0) const { + if (auto leaf = boost::dynamic_pointer_cast(node)) { + return f(leaf->constant(), x0); + } else if (auto choice = + boost::dynamic_pointer_cast(node)) { + for (auto&& branch : choice->branches()) x0 = fold(branch, x0); + return x0; + } else { + throw std::invalid_argument("Fold: Invalid NodePtr"); + } + } + + // alias for fold: + X operator()(const NodePtr& node, X x0) const { return fold(node, x0); } + }; + + template + template + X DecisionTree::fold(Func f, X x0) const { + Fold fold(f); + return fold(root_, x0); + } + /*********************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index db8a12a20..4148c6c20 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -229,6 +229,22 @@ namespace gtsam { /** evaluate */ const Y& operator()(const Assignment& x) const; + /** + * @brief Fold a binary function over the tree, returning accumulator. + * + * @tparam X type for accumulator. + * @param f binary function: Y * X -> X returning an updated accumulator. + * @param x0 initial value for accumulator. + * @return X final value for accumulator. + * + * @note X is always passed by value. + */ + template + X fold(Func f, X x0) const; + + /** Retrieve all labels. */ + std::vector labels() const { return std::vector(); } + /** apply Unary operation "op" to f */ DecisionTree apply(const Unary& op) const; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 5976ea2d4..64f46b129 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -123,8 +123,7 @@ struct Ring { /* ******************************************************************************** */ // test DT -TEST(DT, example) -{ +TEST(DecisionTree, example) { // Create labels string A("A"), B("B"), C("C"); @@ -236,8 +235,7 @@ std::function bool_of_int = [](const int& y) { }; typedef DecisionTree StringBoolTree; -TEST(DT, ConvertValuesOnly) -{ +TEST(DecisionTree, ConvertValuesOnly) { // Create labels string A("A"), B("B"); @@ -260,8 +258,7 @@ enum Label { }; typedef DecisionTree LabelBoolTree; -TEST(DT, ConvertBoth) -{ +TEST(DecisionTree, ConvertBoth) { // Create labels string A("A"), B("B"); @@ -288,8 +285,7 @@ TEST(DT, ConvertBoth) /* ******************************************************************************** */ // test Compose expansion -TEST(DT, Compose) -{ +TEST(DecisionTree, Compose) { // Create labels string A("A"), B("B"), C("C"); @@ -314,6 +310,49 @@ TEST(DT, Compose) DOT(f5); } +/* ******************************************************************************** */ +// Check we can create a decision tree of containers. +TEST(DecisionTree, Containers) { + using Container = std::vector; + using StringContainerTree = DecisionTree; + + // Check default constructor + StringContainerTree tree; + + // Create small two-level tree + string A("A"), B("B"), C("C"); + DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3)); + + // Check conversion + std::function container_of_int = [](const int& i) { + Container c; + c.emplace_back(i); + return c; + }; + StringContainerTree converted(stringIntTree, container_of_int); +} + +/* ******************************************************************************** */ +// Test fold. +TEST(DecisionTree, fold) { + // Create small two-level tree + string A("A"), B("B"), C("C"); + DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); + auto add = [](const int& y, double x) { return y + x; }; + double sum = tree.fold(add, 0.0); + EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); +} + +/* ******************************************************************************** */ +// Test retrieving all labels. +TEST_DISABLED(DecisionTree, labels) { + // Create small two-level tree + string A("A"), B("B"), C("C"); + DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); + auto labels = tree.labels(); + EXPECT_LONGS_EQUAL(2, labels.size()); +} + /* ************************************************************************* */ int main() { TestResult tr;