From 5de3dc42bd3b52421f1165247243e1ee53b0771c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 3 Jan 2022 22:43:45 -0500 Subject: [PATCH] visit and visitWith --- gtsam/discrete/DecisionTree-inl.h | 104 +++++++++++++++------- gtsam/discrete/DecisionTree.h | 35 +++++++- gtsam/discrete/tests/testDecisionTree.cpp | 26 +++++- 3 files changed, 129 insertions(+), 36 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 5e17fa7d6..a1ba0e4c1 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -631,48 +631,86 @@ namespace gtsam { } /*********************************************************************************/ - template - struct Fold { - std::function f; + // Functor performing depth-first visit without Assignment argument. + template + struct Visit { + using F = std::function; + Visit(F f) : f(f) {} ///< Construct from folding function. + F f; ///< folding function object. - /// Construct from folding function - Fold(std::function f) : f(f) {} + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) const { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(leaf->constant()); - using NodePtr = typename DecisionTree::NodePtr; - using Choice = typename DecisionTree::Choice; - using Leaf = typename DecisionTree::Leaf; - - /** - * @brief Do a depth-first fold on the tree rooted at node. - * - * @param node root of a (sub-) tree, or a leaf. - * @param x0 Initial accumulator value. - * @return X Final accumulator value. - */ - X fold(const NodePtr& node, X x0) const { - if (auto leaf = boost::dynamic_pointer_cast(node)) { - return f(leaf->constant(), x0); - } else if (auto choice = - boost::dynamic_pointer_cast(node)) { - for (auto&& branch : choice->branches()) x0 = fold(branch, x0); - return x0; - } else { - throw std::invalid_argument("Fold: Invalid NodePtr"); - } + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + for (auto&& branch : choice->branches()) (*this)(branch); // recurse! } - - // alias for fold: - X operator()(const NodePtr& node, X x0) const { return fold(node, x0); } }; template - template - X DecisionTree::fold(Func f, X x0) const { - Fold fold(f); - return fold(root_, x0); + template + void DecisionTree::visit(Func f) const { + Visit visit(f); + visit(root_); } /*********************************************************************************/ + // Functor performing depth-first visit with Assignment argument. + template + struct VisitWith { + using Choices = Assignment; + using F = std::function; + VisitWith(F f) : f(f) {} ///< Construct from folding function. + Choices choices; ///< Assignment, mutating through recursion. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(choices, leaf->constant()); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + for (size_t i = 0; i < choice->nrChoices(); i++) { + choices[choice->label()] = i; // Set assignment for label to i + (*this)(choice->branches()[i]); // recurse! + } + } + }; + + template + template + void DecisionTree::visitWith(Func f) const { + VisitWith visit(f); + visit(root_); + } + + /*********************************************************************************/ + // fold is just done with a visit + template + template + X DecisionTree::fold(Func f, X x0) const { + visit([&](const Y& y) { x0 = f(y, x0); }); + return x0; + } + + /*********************************************************************************/ + // labels is just done with a visit + template + std::set DecisionTree::labels() const { + std::set unique; + auto f = [&](const Assignment& choices, const Y&) { + for (auto&& kv : choices) unique.insert(kv.first); + }; + visitWith(f); + return unique; + } + +/*********************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, const CompareFunc& compare) const { diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 4148c6c20..fbcb59665 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -28,6 +28,7 @@ #include #include #include +#include namespace gtsam { @@ -229,6 +230,32 @@ namespace gtsam { /** evaluate */ const Y& operator()(const Assignment& x) const; + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f side-effect taking a value. + * + * Example: + * int sum = 0; + * auto visitor = [&](int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visit(Func f) const; + + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f side-effect taking an assignment and a value. + * + * Example: + * int sum = 0; + * auto visitor = [&](const Assignment& choices, int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visitWith(Func f) const; + /** * @brief Fold a binary function over the tree, returning accumulator. * @@ -238,12 +265,16 @@ namespace gtsam { * @return X final value for accumulator. * * @note X is always passed by value. + * + * Example: + * auto add = [](const double& y, double x) { return y + x; }; + * double sum = tree.fold(add, 0.0); */ template X fold(Func f, X x0) const; - /** Retrieve all labels. */ - std::vector labels() const { return std::vector(); } + /** Retrieve all unique labels as a set. */ + std::set labels() const; /** apply Unary operation "op" to f */ DecisionTree apply(const Unary& op) const; diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 64f46b129..84b9ca5fa 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -332,6 +332,30 @@ TEST(DecisionTree, Containers) { StringContainerTree converted(stringIntTree, container_of_int); } +/* ******************************************************************************** */ +// Test visit. +TEST(DecisionTree, visit) { + // Create small two-level tree + string A("A"), B("B"), C("C"); + DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); + double sum = 0.0; + auto visitor = [&](int y) { sum += y; }; + tree.visit(visitor); + EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); +} + +/* ******************************************************************************** */ +// Test visit, with Choices argument. +TEST(DecisionTree, visitWith) { + // Create small two-level tree + string A("A"), B("B"), C("C"); + DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); + double sum = 0.0; + auto visitor = [&](const Assignment& choices, int y) { sum += y; }; + tree.visitWith(visitor); + EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); +} + /* ******************************************************************************** */ // Test fold. TEST(DecisionTree, fold) { @@ -345,7 +369,7 @@ TEST(DecisionTree, fold) { /* ******************************************************************************** */ // Test retrieving all labels. -TEST_DISABLED(DecisionTree, labels) { +TEST(DecisionTree, labels) { // Create small two-level tree string A("A"), B("B"), C("C"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));