diff --git a/gtsam/base/treeTraversal-inst.h b/gtsam/base/treeTraversal-inst.h index 87afa6187..44f30a01d 100644 --- a/gtsam/base/treeTraversal-inst.h +++ b/gtsam/base/treeTraversal-inst.h @@ -26,6 +26,10 @@ #include #include +#include +#undef max // TBB seems to include windows.h and we don't want these macros +#undef min + namespace gtsam { /** Internal functions used for traversing trees */ @@ -47,6 +51,89 @@ namespace gtsam { /// Do nothing - default argument for post-visitor for tree traversal template void no_op(const boost::shared_ptr& node, const DATA& data) {} + + template + class PreOrderTask : public tbb::task + { + public: + const boost::shared_ptr& treeNode; + DATA& myData; + VISITOR_PRE& visitorPre; + VISITOR_POST& visitorPost; + PreOrderTask(const boost::shared_ptr& treeNode, DATA& myData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost) : + treeNode(treeNode), myData(myData), visitorPre(visitorPre), visitorPost(visitorPost) {} + + tbb::task* execute() + { + // Set TBB ref count + set_ref_count(1 + (int)treeNode->children.size()); + // Create data and tasks for our children + std::vector childData; + childData.reserve(treeNode->children.size()); + std::vector tasks; + tasks.reserve(treeNode->children.size()); + BOOST_FOREACH(const boost::shared_ptr& child, treeNode->children) + { + childData.push_back(visitorPre(child, myData)); + tasks.push_back(new(allocate_child()) PreOrderTask(child, childData.back(), visitorPre, visitorPost)); + } + // Spawn tasks + BOOST_FOREACH(PreOrderTask* task, tasks) + spawn(*task); + // Wait for tasks to finish + wait_for_all(); + + // Now that the children are finished, run the post-order visitor + (void) visitorPost(treeNode, myData); + + // Return NULL + return NULL; + } + }; + + template + class RootTask : public tbb::task + { + public: + const ROOTS& roots; + DATA& myData; + VISITOR_PRE& visitorPre; + VISITOR_POST& visitorPost; + RootTask(const ROOTS& roots, DATA& myData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost) : + roots(roots), myData(myData), visitorPre(visitorPre), visitorPost(visitorPost) {} + + tbb::task* execute() + { + typedef PreOrderTask PreOrderTask; + // Set TBB ref count + set_ref_count(1 + (int)roots.size()); + // Create data and tasks for our children + std::vector rootData; + rootData.reserve(roots.size()); + std::vector tasks; + tasks.reserve(roots.size()); + BOOST_FOREACH(const boost::shared_ptr& root, roots) + { + rootData.push_back(visitorPre(root, myData)); + tasks.push_back(new(allocate_child()) PreOrderTask(root, rootData.back(), visitorPre, visitorPost)); + } + // Spawn tasks + BOOST_FOREACH(PreOrderTask* task, tasks) + spawn(*task); + // Wait for tasks to finish + wait_for_all(); + // Return NULL + return NULL; + } + }; + + template + RootTask& + CreateRootTask(const ROOTS& roots, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost) + { + typedef RootTask RootTask; + return *new(tbb::task::allocate_root()) RootTask(roots, rootData, visitorPre, visitorPost); + } } /** Traverse a forest depth-first with pre-order and post-order visits. @@ -124,7 +211,31 @@ namespace gtsam { { DepthFirstForest(forest, rootData, visitorPre, no_op); } - + + /** Traverse a forest depth-first with pre-order and post-order visits. + * @param forest The forest of trees to traverse. The method \c forest.roots() should exist + * and return a collection of (shared) pointers to \c FOREST::Node. + * @param visitorPre \c visitorPre(node, parentData) will be called at every node, before + * visiting its children, and will be passed, by reference, the \c DATA object returned + * by the visit to its parent. Likewise, \c visitorPre should return the \c DATA object + * to pass to the children. The returned \c DATA object will be copy-constructed only + * upon returning to store internally, thus may be modified by visiting the children. + * Regarding efficiency, this copy-on-return is usually optimized out by the compiler. + * @param visitorPost \c visitorPost(node, data) will be called at every node, after visiting + * its children, and will be passed, by reference, the \c DATA object returned by the + * call to \c visitorPre (the \c DATA object may be modified by visiting the children). + * @param rootData The data to pass by reference to \c visitorPre when it is called on each + * root node. */ + template + void DepthFirstForestParallel(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost) + { + // Typedefs + typedef typename FOREST::Node Node; + typedef boost::shared_ptr sharedNode; + + tbb::task::spawn_root_and_wait(CreateRootTask(forest.roots(), rootData, visitorPre, visitorPost)); + } + /* ************************************************************************* */ /** Traversal function for CloneForest */