From 15850333b48ce2c1e05ea192ae29201200fb08dc Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 3 Jan 2022 17:28:10 -0500 Subject: [PATCH 1/6] Straight-up depth-first fold method # Conflicts: # gtsam/discrete/tests/testDecisionTree.cpp --- gtsam/discrete/DecisionTree-inl.h | 42 +++++++++++++++++ gtsam/discrete/DecisionTree.h | 16 +++++++ gtsam/discrete/tests/testDecisionTree.cpp | 55 +++++++++++++++++++---- 3 files changed, 105 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 11ecbf183..5e17fa7d6 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -630,6 +630,48 @@ namespace gtsam { return LY::compose(functions.begin(), functions.end(), newLabel); } + /*********************************************************************************/ + template + struct Fold { + std::function f; + + /// Construct from folding function + Fold(std::function f) : f(f) {} + + using NodePtr = typename DecisionTree::NodePtr; + using Choice = typename DecisionTree::Choice; + using Leaf = typename DecisionTree::Leaf; + + /** + * @brief Do a depth-first fold on the tree rooted at node. + * + * @param node root of a (sub-) tree, or a leaf. + * @param x0 Initial accumulator value. + * @return X Final accumulator value. + */ + X fold(const NodePtr& node, X x0) const { + if (auto leaf = boost::dynamic_pointer_cast(node)) { + return f(leaf->constant(), x0); + } else if (auto choice = + boost::dynamic_pointer_cast(node)) { + for (auto&& branch : choice->branches()) x0 = fold(branch, x0); + return x0; + } else { + throw std::invalid_argument("Fold: Invalid NodePtr"); + } + } + + // alias for fold: + X operator()(const NodePtr& node, X x0) const { return fold(node, x0); } + }; + + template + template + X DecisionTree::fold(Func f, X x0) const { + Fold fold(f); + return fold(root_, x0); + } + /*********************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index db8a12a20..4148c6c20 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -229,6 +229,22 @@ namespace gtsam { /** evaluate */ const Y& operator()(const Assignment& x) const; + /** + * @brief Fold a binary function over the tree, returning accumulator. + * + * @tparam X type for accumulator. + * @param f binary function: Y * X -> X returning an updated accumulator. + * @param x0 initial value for accumulator. + * @return X final value for accumulator. + * + * @note X is always passed by value. + */ + template + X fold(Func f, X x0) const; + + /** Retrieve all labels. */ + std::vector labels() const { return std::vector(); } + /** apply Unary operation "op" to f */ DecisionTree apply(const Unary& op) const; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 5976ea2d4..64f46b129 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -123,8 +123,7 @@ struct Ring { /* ******************************************************************************** */ // test DT -TEST(DT, example) -{ +TEST(DecisionTree, example) { // Create labels string A("A"), B("B"), C("C"); @@ -236,8 +235,7 @@ std::function bool_of_int = [](const int& y) { }; typedef DecisionTree StringBoolTree; -TEST(DT, ConvertValuesOnly) -{ +TEST(DecisionTree, ConvertValuesOnly) { // Create labels string A("A"), B("B"); @@ -260,8 +258,7 @@ enum Label { }; typedef DecisionTree LabelBoolTree; -TEST(DT, ConvertBoth) -{ +TEST(DecisionTree, ConvertBoth) { // Create labels string A("A"), B("B"); @@ -288,8 +285,7 @@ TEST(DT, ConvertBoth) /* ******************************************************************************** */ // test Compose expansion -TEST(DT, Compose) -{ +TEST(DecisionTree, Compose) { // Create labels string A("A"), B("B"), C("C"); @@ -314,6 +310,49 @@ TEST(DT, Compose) DOT(f5); } +/* ******************************************************************************** */ +// Check we can create a decision tree of containers. +TEST(DecisionTree, Containers) { + using Container = std::vector; + using StringContainerTree = DecisionTree; + + // Check default constructor + StringContainerTree tree; + + // Create small two-level tree + string A("A"), B("B"), C("C"); + DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3)); + + // Check conversion + std::function container_of_int = [](const int& i) { + Container c; + c.emplace_back(i); + return c; + }; + StringContainerTree converted(stringIntTree, container_of_int); +} + +/* ******************************************************************************** */ +// Test fold. +TEST(DecisionTree, fold) { + // Create small two-level tree + string A("A"), B("B"), C("C"); + DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); + auto add = [](const int& y, double x) { return y + x; }; + double sum = tree.fold(add, 0.0); + EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); +} + +/* ******************************************************************************** */ +// Test retrieving all labels. +TEST_DISABLED(DecisionTree, labels) { + // Create small two-level tree + string A("A"), B("B"), C("C"); + DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); + auto labels = tree.labels(); + EXPECT_LONGS_EQUAL(2, labels.size()); +} + /* ************************************************************************* */ int main() { TestResult tr; From 5de3dc42bd3b52421f1165247243e1ee53b0771c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 3 Jan 2022 22:43:45 -0500 Subject: [PATCH 2/6] visit and visitWith --- gtsam/discrete/DecisionTree-inl.h | 104 +++++++++++++++------- gtsam/discrete/DecisionTree.h | 35 +++++++- gtsam/discrete/tests/testDecisionTree.cpp | 26 +++++- 3 files changed, 129 insertions(+), 36 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 5e17fa7d6..a1ba0e4c1 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -631,48 +631,86 @@ namespace gtsam { } /*********************************************************************************/ - template - struct Fold { - std::function f; + // Functor performing depth-first visit without Assignment argument. + template + struct Visit { + using F = std::function; + Visit(F f) : f(f) {} ///< Construct from folding function. + F f; ///< folding function object. - /// Construct from folding function - Fold(std::function f) : f(f) {} + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) const { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(leaf->constant()); - using NodePtr = typename DecisionTree::NodePtr; - using Choice = typename DecisionTree::Choice; - using Leaf = typename DecisionTree::Leaf; - - /** - * @brief Do a depth-first fold on the tree rooted at node. - * - * @param node root of a (sub-) tree, or a leaf. - * @param x0 Initial accumulator value. - * @return X Final accumulator value. - */ - X fold(const NodePtr& node, X x0) const { - if (auto leaf = boost::dynamic_pointer_cast(node)) { - return f(leaf->constant(), x0); - } else if (auto choice = - boost::dynamic_pointer_cast(node)) { - for (auto&& branch : choice->branches()) x0 = fold(branch, x0); - return x0; - } else { - throw std::invalid_argument("Fold: Invalid NodePtr"); - } + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + for (auto&& branch : choice->branches()) (*this)(branch); // recurse! } - - // alias for fold: - X operator()(const NodePtr& node, X x0) const { return fold(node, x0); } }; template - template - X DecisionTree::fold(Func f, X x0) const { - Fold fold(f); - return fold(root_, x0); + template + void DecisionTree::visit(Func f) const { + Visit visit(f); + visit(root_); } /*********************************************************************************/ + // Functor performing depth-first visit with Assignment argument. + template + struct VisitWith { + using Choices = Assignment; + using F = std::function; + VisitWith(F f) : f(f) {} ///< Construct from folding function. + Choices choices; ///< Assignment, mutating through recursion. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(choices, leaf->constant()); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + for (size_t i = 0; i < choice->nrChoices(); i++) { + choices[choice->label()] = i; // Set assignment for label to i + (*this)(choice->branches()[i]); // recurse! + } + } + }; + + template + template + void DecisionTree::visitWith(Func f) const { + VisitWith visit(f); + visit(root_); + } + + /*********************************************************************************/ + // fold is just done with a visit + template + template + X DecisionTree::fold(Func f, X x0) const { + visit([&](const Y& y) { x0 = f(y, x0); }); + return x0; + } + + /*********************************************************************************/ + // labels is just done with a visit + template + std::set DecisionTree::labels() const { + std::set unique; + auto f = [&](const Assignment& choices, const Y&) { + for (auto&& kv : choices) unique.insert(kv.first); + }; + visitWith(f); + return unique; + } + +/*********************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, const CompareFunc& compare) const { diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 4148c6c20..fbcb59665 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -28,6 +28,7 @@ #include #include #include +#include namespace gtsam { @@ -229,6 +230,32 @@ namespace gtsam { /** evaluate */ const Y& operator()(const Assignment& x) const; + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f side-effect taking a value. + * + * Example: + * int sum = 0; + * auto visitor = [&](int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visit(Func f) const; + + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f side-effect taking an assignment and a value. + * + * Example: + * int sum = 0; + * auto visitor = [&](const Assignment& choices, int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visitWith(Func f) const; + /** * @brief Fold a binary function over the tree, returning accumulator. * @@ -238,12 +265,16 @@ namespace gtsam { * @return X final value for accumulator. * * @note X is always passed by value. + * + * Example: + * auto add = [](const double& y, double x) { return y + x; }; + * double sum = tree.fold(add, 0.0); */ template X fold(Func f, X x0) const; - /** Retrieve all labels. */ - std::vector labels() const { return std::vector(); } + /** Retrieve all unique labels as a set. */ + std::set labels() const; /** apply Unary operation "op" to f */ DecisionTree apply(const Unary& op) const; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 64f46b129..84b9ca5fa 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -332,6 +332,30 @@ TEST(DecisionTree, Containers) { StringContainerTree converted(stringIntTree, container_of_int); } +/* ******************************************************************************** */ +// Test visit. +TEST(DecisionTree, visit) { + // Create small two-level tree + string A("A"), B("B"), C("C"); + DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); + double sum = 0.0; + auto visitor = [&](int y) { sum += y; }; + tree.visit(visitor); + EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); +} + +/* ******************************************************************************** */ +// Test visit, with Choices argument. +TEST(DecisionTree, visitWith) { + // Create small two-level tree + string A("A"), B("B"), C("C"); + DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); + double sum = 0.0; + auto visitor = [&](const Assignment& choices, int y) { sum += y; }; + tree.visitWith(visitor); + EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); +} + /* ******************************************************************************** */ // Test fold. TEST(DecisionTree, fold) { @@ -345,7 +369,7 @@ TEST(DecisionTree, fold) { /* ******************************************************************************** */ // Test retrieving all labels. -TEST_DISABLED(DecisionTree, labels) { +TEST(DecisionTree, labels) { // Create small two-level tree string A("A"), B("B"), C("C"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); From b453152f3eeff2425755073a9fb9803d73bffc1c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 3 Jan 2022 23:44:51 -0500 Subject: [PATCH 3/6] Use template parameter for functions, enables auto # Conflicts: # gtsam/discrete/DecisionTree-inl.h # gtsam/discrete/DecisionTree.h --- gtsam/discrete/DecisionTree-inl.h | 35 ++++++++--------------- gtsam/discrete/DecisionTree.h | 11 ++++--- gtsam/discrete/tests/testDecisionTree.cpp | 8 ++---- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index a1ba0e4c1..a0df966a0 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -82,13 +82,7 @@ namespace gtsam { return compare(this->constant_, other->constant_); } - /** - * @brief Print method. - * - * @param s Prefix string. - * @param labelFormatter Functor to format the labels of type L. - * @param valueFormatter Functor to format the values of type Y. - */ + /** print */ void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; @@ -465,23 +459,20 @@ namespace gtsam { /*********************************************************************************/ template - template + template DecisionTree::DecisionTree(const DecisionTree& other, - std::function Y_of_X) { + Func Y_of_X) { // Define functor for identity mapping of node label. - auto L_of_L = [](const L& label) { return label; }; + auto L_of_L = [](const L& label) { return label; }; root_ = convertFrom(other.root_, L_of_L, Y_of_X); } /*********************************************************************************/ template - template + template DecisionTree::DecisionTree(const DecisionTree& other, - const std::map& map, - std::function Y_of_X) { - std::function L_of_M = [&map](const M& label) -> L { - return map.at(label); - }; + const std::map& map, Func Y_of_X) { + auto L_of_M = [&map](const M& label) -> L { return map.at(label); }; root_ = convertFrom(other.root_, L_of_M, Y_of_X); } @@ -601,18 +592,16 @@ namespace gtsam { const typename DecisionTree::NodePtr& f, std::function L_of_M, std::function Y_of_X) const { - using MX = DecisionTree; - using MXLeaf = typename MX::Leaf; - using MXChoice = typename MX::Choice; - using MXNodePtr = typename MX::NodePtr; using LY = DecisionTree; // ugliness below because apparently we can't have templated virtual functions // If leaf, apply unary conversion "op" and create a unique leaf - auto leaf = boost::dynamic_pointer_cast(f); - if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant()))); + using MXLeaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(f)) + return NodePtr(new Leaf(Y_of_X(leaf->constant()))); // Check if Choice + using MXChoice = typename DecisionTree::Choice; auto choice = boost::dynamic_pointer_cast(f); if (!choice) throw std::invalid_argument( "DecisionTree::Convert: Invalid NodePtr"); @@ -623,7 +612,7 @@ namespace gtsam { // put together via Shannon expansion otherwise not sorted. std::vector functions; - for(const MXNodePtr& branch: choice->branches()) { + for(auto && branch: choice->branches()) { LY converted(convertFrom(branch, L_of_M, Y_of_X)); functions += converted; } diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index fbcb59665..9692094e1 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -177,9 +177,8 @@ namespace gtsam { * @param other The DecisionTree to convert from. * @param Y_of_X Functor to convert from value type X to type Y. */ - template - DecisionTree(const DecisionTree& other, - std::function Y_of_X); + template + DecisionTree(const DecisionTree& other, Func Y_of_X); /** * @brief Convert from a different value type X to value type Y, also transate @@ -191,9 +190,9 @@ namespace gtsam { * @param L_of_M Map from label type M to type L. * @param Y_of_X Functor to convert from type X to type Y. */ - template - DecisionTree(const DecisionTree& other, const std::map& L_of_M, - std::function Y_of_X); + template + DecisionTree(const DecisionTree& other, const std::map& map, + Func Y_of_X); /// @} /// @name Testable diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 84b9ca5fa..2e6ec59f7 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -230,9 +230,7 @@ TEST(DecisionTree, example) { /* ******************************************************************************** */ // test Conversion of values -std::function bool_of_int = [](const int& y) { - return y != 0; -}; +bool bool_of_int(const int& y) { return y != 0; }; typedef DecisionTree StringBoolTree; TEST(DecisionTree, ConvertValuesOnly) { @@ -269,7 +267,7 @@ TEST(DecisionTree, ConvertBoth) { map ordering; ordering[A] = X; ordering[B] = Y; - LabelBoolTree f2(f1, ordering, bool_of_int); + LabelBoolTree f2(f1, ordering, &bool_of_int); // Check some values Assignment