visit and visitWith

release/4.3a0
Frank Dellaert 2022-01-03 22:43:45 -05:00
parent 15850333b4
commit 5de3dc42bd
3 changed files with 129 additions and 36 deletions

View File

@ -631,48 +631,86 @@ namespace gtsam {
} }
/*********************************************************************************/ /*********************************************************************************/
template <typename L, typename Y, typename X> // Functor performing depth-first visit without Assignment<L> argument.
struct Fold { template <typename L, typename Y>
std::function<X(const Y&, X)> f; struct Visit {
using F = std::function<void(const Y&)>;
Visit(F f) : f(f) {} ///< Construct from folding function.
F f; ///< folding function object.
/// Construct from folding function /// Do a depth-first visit on the tree rooted at node.
Fold(std::function<X(const Y&, X)> f) : f(f) {} void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
using Leaf = typename DecisionTree<L, Y>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
return f(leaf->constant());
using NodePtr = typename DecisionTree<L, Y>::NodePtr; using Choice = typename DecisionTree<L, Y>::Choice;
using Choice = typename DecisionTree<L, Y>::Choice; auto choice = boost::dynamic_pointer_cast<const Choice>(node);
using Leaf = typename DecisionTree<L, Y>::Leaf; for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
/**
* @brief Do a depth-first fold on the tree rooted at node.
*
* @param node root of a (sub-) tree, or a leaf.
* @param x0 Initial accumulator value.
* @return X Final accumulator value.
*/
X fold(const NodePtr& node, X x0) const {
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node)) {
return f(leaf->constant(), x0);
} else if (auto choice =
boost::dynamic_pointer_cast<const Choice>(node)) {
for (auto&& branch : choice->branches()) x0 = fold(branch, x0);
return x0;
} else {
throw std::invalid_argument("Fold: Invalid NodePtr");
}
} }
// alias for fold:
X operator()(const NodePtr& node, X x0) const { return fold(node, x0); }
}; };
template <typename L, typename Y> template <typename L, typename Y>
template <typename Func, typename X> template <typename Func>
X DecisionTree<L, Y>::fold(Func f, X x0) const { void DecisionTree<L, Y>::visit(Func f) const {
Fold<L, Y, X> fold(f); Visit<L, Y> visit(f);
return fold(root_, x0); visit(root_);
} }
/*********************************************************************************/ /*********************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument.
template <typename L, typename Y>
struct VisitWith {
using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>;
VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
using Leaf = typename DecisionTree<L, Y>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
return f(choices, leaf->constant());
using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse!
}
}
};
template <typename L, typename Y>
template <typename Func>
void DecisionTree<L, Y>::visitWith(Func f) const {
VisitWith<L, Y> visit(f);
visit(root_);
}
/*********************************************************************************/
// fold is just done with a visit
template <typename L, typename Y>
template <typename Func, typename X>
X DecisionTree<L, Y>::fold(Func f, X x0) const {
visit([&](const Y& y) { x0 = f(y, x0); });
return x0;
}
/*********************************************************************************/
// labels is just done with a visit
template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const {
std::set<L> unique;
auto f = [&](const Assignment<L>& choices, const Y&) {
for (auto&& kv : choices) unique.insert(kv.first);
};
visitWith(f);
return unique;
}
/*********************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other, bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const { const CompareFunc& compare) const {

View File

@ -28,6 +28,7 @@
#include <map> #include <map>
#include <sstream> #include <sstream>
#include <vector> #include <vector>
#include <set>
namespace gtsam { namespace gtsam {
@ -229,6 +230,32 @@ namespace gtsam {
/** evaluate */ /** evaluate */
const Y& operator()(const Assignment<L>& x) const; const Y& operator()(const Assignment<L>& x) const;
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking a value.
*
* Example:
* int sum = 0;
* auto visitor = [&](int y) { sum += y; };
* tree.visitWith(visitor);
*/
template <typename Func>
void visit(Func f) const;
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking an assignment and a value.
*
* Example:
* int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
* tree.visitWith(visitor);
*/
template <typename Func>
void visitWith(Func f) const;
/** /**
* @brief Fold a binary function over the tree, returning accumulator. * @brief Fold a binary function over the tree, returning accumulator.
* *
@ -238,12 +265,16 @@ namespace gtsam {
* @return X final value for accumulator. * @return X final value for accumulator.
* *
* @note X is always passed by value. * @note X is always passed by value.
*
* Example:
* auto add = [](const double& y, double x) { return y + x; };
* double sum = tree.fold(add, 0.0);
*/ */
template <typename Func, typename X> template <typename Func, typename X>
X fold(Func f, X x0) const; X fold(Func f, X x0) const;
/** Retrieve all labels. */ /** Retrieve all unique labels as a set. */
std::vector<L> labels() const { return std::vector<L>(); } std::set<L> labels() const;
/** apply Unary operation "op" to f */ /** apply Unary operation "op" to f */
DecisionTree apply(const Unary& op) const; DecisionTree apply(const Unary& op) const;

View File

@ -332,6 +332,30 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int); StringContainerTree converted(stringIntTree, container_of_int);
} }
/* ******************************************************************************** */
// Test visit.
TEST(DecisionTree, visit) {
// Create small two-level tree
string A("A"), B("B"), C("C");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0;
auto visitor = [&](int y) { sum += y; };
tree.visit(visitor);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
}
/* ******************************************************************************** */
// Test visit, with Choices argument.
TEST(DecisionTree, visitWith) {
// Create small two-level tree
string A("A"), B("B"), C("C");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0;
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
tree.visitWith(visitor);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
}
/* ******************************************************************************** */ /* ******************************************************************************** */
// Test fold. // Test fold.
TEST(DecisionTree, fold) { TEST(DecisionTree, fold) {
@ -345,7 +369,7 @@ TEST(DecisionTree, fold) {
/* ******************************************************************************** */ /* ******************************************************************************** */
// Test retrieving all labels. // Test retrieving all labels.
TEST_DISABLED(DecisionTree, labels) { TEST(DecisionTree, labels) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B"), C("C");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));