visit and visitWith
parent
15850333b4
commit
5de3dc42bd
|
|
@ -631,48 +631,86 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
template <typename L, typename Y, typename X>
|
// Functor performing depth-first visit without Assignment<L> argument.
|
||||||
struct Fold {
|
template <typename L, typename Y>
|
||||||
std::function<X(const Y&, X)> f;
|
struct Visit {
|
||||||
|
using F = std::function<void(const Y&)>;
|
||||||
|
Visit(F f) : f(f) {} ///< Construct from folding function.
|
||||||
|
F f; ///< folding function object.
|
||||||
|
|
||||||
/// Construct from folding function
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
Fold(std::function<X(const Y&, X)> f) : f(f) {}
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
|
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
|
||||||
|
return f(leaf->constant());
|
||||||
|
|
||||||
using NodePtr = typename DecisionTree<L, Y>::NodePtr;
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Do a depth-first fold on the tree rooted at node.
|
|
||||||
*
|
|
||||||
* @param node root of a (sub-) tree, or a leaf.
|
|
||||||
* @param x0 Initial accumulator value.
|
|
||||||
* @return X Final accumulator value.
|
|
||||||
*/
|
|
||||||
X fold(const NodePtr& node, X x0) const {
|
|
||||||
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node)) {
|
|
||||||
return f(leaf->constant(), x0);
|
|
||||||
} else if (auto choice =
|
|
||||||
boost::dynamic_pointer_cast<const Choice>(node)) {
|
|
||||||
for (auto&& branch : choice->branches()) x0 = fold(branch, x0);
|
|
||||||
return x0;
|
|
||||||
} else {
|
|
||||||
throw std::invalid_argument("Fold: Invalid NodePtr");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// alias for fold:
|
|
||||||
X operator()(const NodePtr& node, X x0) const { return fold(node, x0); }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename Func, typename X>
|
template <typename Func>
|
||||||
X DecisionTree<L, Y>::fold(Func f, X x0) const {
|
void DecisionTree<L, Y>::visit(Func f) const {
|
||||||
Fold<L, Y, X> fold(f);
|
Visit<L, Y> visit(f);
|
||||||
return fold(root_, x0);
|
visit(root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/*********************************************************************************/
|
||||||
|
// Functor performing depth-first visit with Assignment<L> argument.
|
||||||
|
template <typename L, typename Y>
|
||||||
|
struct VisitWith {
|
||||||
|
using Choices = Assignment<L>;
|
||||||
|
using F = std::function<void(const Choices&, const Y&)>;
|
||||||
|
VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
||||||
|
Choices choices; ///< Assignment, mutating through recursion.
|
||||||
|
F f; ///< folding function object.
|
||||||
|
|
||||||
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
||||||
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
|
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
|
||||||
|
return f(choices, leaf->constant());
|
||||||
|
|
||||||
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
|
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||||
|
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||||
|
choices[choice->label()] = i; // Set assignment for label to i
|
||||||
|
(*this)(choice->branches()[i]); // recurse!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Func>
|
||||||
|
void DecisionTree<L, Y>::visitWith(Func f) const {
|
||||||
|
VisitWith<L, Y> visit(f);
|
||||||
|
visit(root_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************/
|
||||||
|
// fold is just done with a visit
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Func, typename X>
|
||||||
|
X DecisionTree<L, Y>::fold(Func f, X x0) const {
|
||||||
|
visit([&](const Y& y) { x0 = f(y, x0); });
|
||||||
|
return x0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************/
|
||||||
|
// labels is just done with a visit
|
||||||
|
template <typename L, typename Y>
|
||||||
|
std::set<L> DecisionTree<L, Y>::labels() const {
|
||||||
|
std::set<L> unique;
|
||||||
|
auto f = [&](const Assignment<L>& choices, const Y&) {
|
||||||
|
for (auto&& kv : choices) unique.insert(kv.first);
|
||||||
|
};
|
||||||
|
visitWith(f);
|
||||||
|
return unique;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||||
const CompareFunc& compare) const {
|
const CompareFunc& compare) const {
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -229,6 +230,32 @@ namespace gtsam {
|
||||||
/** evaluate */
|
/** evaluate */
|
||||||
const Y& operator()(const Assignment<L>& x) const;
|
const Y& operator()(const Assignment<L>& x) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Visit all leaves in depth-first fashion.
|
||||||
|
*
|
||||||
|
* @param f side-effect taking a value.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* int sum = 0;
|
||||||
|
* auto visitor = [&](int y) { sum += y; };
|
||||||
|
* tree.visitWith(visitor);
|
||||||
|
*/
|
||||||
|
template <typename Func>
|
||||||
|
void visit(Func f) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Visit all leaves in depth-first fashion.
|
||||||
|
*
|
||||||
|
* @param f side-effect taking an assignment and a value.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* int sum = 0;
|
||||||
|
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
|
||||||
|
* tree.visitWith(visitor);
|
||||||
|
*/
|
||||||
|
template <typename Func>
|
||||||
|
void visitWith(Func f) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Fold a binary function over the tree, returning accumulator.
|
* @brief Fold a binary function over the tree, returning accumulator.
|
||||||
*
|
*
|
||||||
|
|
@ -238,12 +265,16 @@ namespace gtsam {
|
||||||
* @return X final value for accumulator.
|
* @return X final value for accumulator.
|
||||||
*
|
*
|
||||||
* @note X is always passed by value.
|
* @note X is always passed by value.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* auto add = [](const double& y, double x) { return y + x; };
|
||||||
|
* double sum = tree.fold(add, 0.0);
|
||||||
*/
|
*/
|
||||||
template <typename Func, typename X>
|
template <typename Func, typename X>
|
||||||
X fold(Func f, X x0) const;
|
X fold(Func f, X x0) const;
|
||||||
|
|
||||||
/** Retrieve all labels. */
|
/** Retrieve all unique labels as a set. */
|
||||||
std::vector<L> labels() const { return std::vector<L>(); }
|
std::set<L> labels() const;
|
||||||
|
|
||||||
/** apply Unary operation "op" to f */
|
/** apply Unary operation "op" to f */
|
||||||
DecisionTree apply(const Unary& op) const;
|
DecisionTree apply(const Unary& op) const;
|
||||||
|
|
|
||||||
|
|
@ -332,6 +332,30 @@ TEST(DecisionTree, Containers) {
|
||||||
StringContainerTree converted(stringIntTree, container_of_int);
|
StringContainerTree converted(stringIntTree, container_of_int);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
// Test visit.
|
||||||
|
TEST(DecisionTree, visit) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
double sum = 0.0;
|
||||||
|
auto visitor = [&](int y) { sum += y; };
|
||||||
|
tree.visit(visitor);
|
||||||
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
// Test visit, with Choices argument.
|
||||||
|
TEST(DecisionTree, visitWith) {
|
||||||
|
// Create small two-level tree
|
||||||
|
string A("A"), B("B"), C("C");
|
||||||
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
double sum = 0.0;
|
||||||
|
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
||||||
|
tree.visitWith(visitor);
|
||||||
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// Test fold.
|
// Test fold.
|
||||||
TEST(DecisionTree, fold) {
|
TEST(DecisionTree, fold) {
|
||||||
|
|
@ -345,7 +369,7 @@ TEST(DecisionTree, fold) {
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
// Test retrieving all labels.
|
// Test retrieving all labels.
|
||||||
TEST_DISABLED(DecisionTree, labels) {
|
TEST(DecisionTree, labels) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue