diff --git a/gtsam/base/treeTraversal-inst.h b/gtsam/base/treeTraversal-inst.h index 44f30a01d..e1053a736 100644 --- a/gtsam/base/treeTraversal-inst.h +++ b/gtsam/base/treeTraversal-inst.h @@ -48,47 +48,92 @@ namespace gtsam { expanded(false), treeNode(_treeNode), parentData(_parentData) {} }; - /// Do nothing - default argument for post-visitor for tree traversal + // Do nothing - default argument for post-visitor for tree traversal template void no_op(const boost::shared_ptr& node, const DATA& data) {} + // Internal node used in parallel traversal stack + template + struct ParallelTraversalNode { + const boost::shared_ptr& treeNode; + DATA myData; + ParallelTraversalNode(const boost::shared_ptr& treeNode, const DATA& myData) : + treeNode(treeNode), myData(myData) {} + }; + template class PreOrderTask : public tbb::task { public: const boost::shared_ptr& treeNode; - DATA& myData; + DATA myData; VISITOR_PRE& visitorPre; VISITOR_POST& visitorPost; - PreOrderTask(const boost::shared_ptr& treeNode, DATA& myData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost) : - treeNode(treeNode), myData(myData), visitorPre(visitorPre), visitorPost(visitorPost) {} + int problemSizeThreshold; + PreOrderTask(const boost::shared_ptr& treeNode, const DATA& myData, + VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, int problemSizeThreshold) : + treeNode(treeNode), myData(myData), visitorPre(visitorPre), visitorPost(visitorPost), + problemSizeThreshold(problemSizeThreshold) {} + + typedef ParallelTraversalNode ParallelTraversalNode; tbb::task* execute() { - // Set TBB ref count - set_ref_count(1 + (int)treeNode->children.size()); - // Create data and tasks for our children - std::vector childData; - childData.reserve(treeNode->children.size()); - std::vector tasks; - tasks.reserve(treeNode->children.size()); - BOOST_FOREACH(const boost::shared_ptr& child, treeNode->children) - { - childData.push_back(visitorPre(child, myData)); - tasks.push_back(new(allocate_child()) 
PreOrderTask(child, childData.back(), visitorPre, visitorPost)); - } - // Spawn tasks - BOOST_FOREACH(PreOrderTask* task, tasks) - spawn(*task); - // Wait for tasks to finish - wait_for_all(); + // Running problem size for this task; passed by reference through the recursive processNode calls + int problemSize = 0; - // Now that the children are finished, run the post-order visitor - (void) visitorPost(treeNode, myData); + //std::cout << "New task: " << std::endl; + //BOOST_FOREACH(Key j, treeNode->keys) + // std::cout << j << " "; + //std::cout << std::endl; + + // Process this node and its children + processNode(treeNode, myData, problemSize); // Return NULL return NULL; } + + void processNode(const boost::shared_ptr& node, DATA& myData, int& problemSize) + { + tbb::task_list childTasks; + int nChildTasks = 0; + + // Increment problem size for this node + problemSize += node->problemSize(); + + // A child is processed inline whenever problemSize is below problemSizeThreshold; otherwise it is handed to its own subtask + BOOST_FOREACH(const boost::shared_ptr& child, node->children) + { + if(problemSize < problemSizeThreshold) + { + //std::cout << "problemSize = " << problemSize << std::endl; + //BOOST_FOREACH(Key j, child->keys) + // std::cout << j << " "; + //std::cout << std::endl; + // Process child sequentially (the recursive call increases problemSize further) + DATA childData = visitorPre(child, myData); + processNode(child, childData, problemSize); + } + else + { + // Process child in a subtask + childTasks.push_back(*new(allocate_child()) + PreOrderTask(child, visitorPre(child, myData), visitorPre, visitorPost, problemSizeThreshold)); + ++ nChildTasks; + } + } + + // If we have child tasks, start subtasks and wait for them to complete + if(nChildTasks > 0) { + set_ref_count(1 + nChildTasks); + spawn(childTasks); + wait_for_all(); + } + + // Run the post-order visitor + (void) visitorPost(node, myData); + } }; template @@ -99,8 +144,11 @@ namespace gtsam { DATA& myData; VISITOR_PRE& visitorPre; VISITOR_POST& visitorPost; - RootTask(const ROOTS& roots, DATA& myData, VISITOR_PRE& 
visitorPre, VISITOR_POST& visitorPost) : - roots(roots), myData(myData), visitorPre(visitorPre), visitorPost(visitorPost) {} + int problemSizeThreshold; + RootTask(const ROOTS& roots, DATA& myData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, + int problemSizeThreshold) : + roots(roots), myData(myData), visitorPre(visitorPre), visitorPost(visitorPost), + problemSizeThreshold(problemSizeThreshold) {} tbb::task* execute() { @@ -108,14 +156,12 @@ namespace gtsam { // Set TBB ref count set_ref_count(1 + (int)roots.size()); // Create data and tasks for our children - std::vector rootData; - rootData.reserve(roots.size()); std::vector tasks; tasks.reserve(roots.size()); BOOST_FOREACH(const boost::shared_ptr& root, roots) { - rootData.push_back(visitorPre(root, myData)); - tasks.push_back(new(allocate_child()) PreOrderTask(root, rootData.back(), visitorPre, visitorPost)); + tasks.push_back(new(allocate_child()) + PreOrderTask(root, visitorPre(root, myData), visitorPre, visitorPost, problemSizeThreshold)); } // Spawn tasks BOOST_FOREACH(PreOrderTask* task, tasks) @@ -129,11 +175,19 @@ namespace gtsam { template RootTask& - CreateRootTask(const ROOTS& roots, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost) + CreateRootTask(const ROOTS& roots, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, + int problemSizeThreshold) { typedef RootTask RootTask; - return *new(tbb::task::allocate_root()) RootTask(roots, rootData, visitorPre, visitorPost); + return *new(tbb::task::allocate_root()) RootTask(roots, rootData, visitorPre, visitorPost, problemSizeThreshold); } + + /* ************************************************************************* */ + //template + //struct ParallelDFSData { + // DATA myData; + // FastList >& + //}; } /** Traverse a forest depth-first with pre-order and post-order visits. 
@@ -227,13 +281,15 @@ namespace gtsam { * @param rootData The data to pass by reference to \c visitorPre when it is called on each * root node. */ template - void DepthFirstForestParallel(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost) + void DepthFirstForestParallel(FOREST& forest, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, + int problemSizeThreshold = 50) { // Typedefs typedef typename FOREST::Node Node; typedef boost::shared_ptr sharedNode; - tbb::task::spawn_root_and_wait(CreateRootTask(forest.roots(), rootData, visitorPre, visitorPost)); + tbb::task::spawn_root_and_wait(CreateRootTask( + forest.roots(), rootData, visitorPre, visitorPost, problemSizeThreshold)); }