Revert "Revert "Revert "replace deprecated tbb functionality"""

This reverts commit 21279e6d51.
release/4.3a0
Fan Jiang 2021-10-12 19:03:56 -04:00
parent 1a97b14138
commit dae409373c
2 changed files with 46 additions and 29 deletions

View File

@ -158,8 +158,9 @@ void DepthFirstForestParallel(FOREST& forest, DATA& rootData,
// Typedefs
typedef typename FOREST::Node Node;
internal::CreateRootTask<Node>(forest.roots(), rootData, visitorPre,
visitorPost, problemSizeThreshold);
tbb::task::spawn_root_and_wait(
internal::CreateRootTask<Node>(forest.roots(), rootData, visitorPre,
visitorPost, problemSizeThreshold));
#else
DepthFirstForest(forest, rootData, visitorPre, visitorPost);
#endif

View File

@ -22,7 +22,7 @@
#include <boost/make_shared.hpp>
#ifdef GTSAM_USE_TBB
#include <tbb/task_group.h> // tbb::task_group
#include <tbb/task.h> // tbb::task, tbb::task_list
#include <tbb/scalable_allocator.h> // tbb::scalable_allocator
namespace gtsam {
@ -34,7 +34,7 @@ namespace gtsam {
/* ************************************************************************* */
template<typename NODE, typename DATA, typename VISITOR_PRE, typename VISITOR_POST>
class PreOrderTask
class PreOrderTask : public tbb::task
{
public:
const boost::shared_ptr<NODE>& treeNode;
@ -42,30 +42,28 @@ namespace gtsam {
VISITOR_PRE& visitorPre;
VISITOR_POST& visitorPost;
int problemSizeThreshold;
tbb::task_group& tg;
bool makeNewTasks;
// Keep track of order phase across multiple calls to the same functor
mutable bool isPostOrderPhase;
bool isPostOrderPhase;
PreOrderTask(const boost::shared_ptr<NODE>& treeNode, const boost::shared_ptr<DATA>& myData,
VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, int problemSizeThreshold,
tbb::task_group& tg, bool makeNewTasks = true)
bool makeNewTasks = true)
: treeNode(treeNode),
myData(myData),
visitorPre(visitorPre),
visitorPost(visitorPost),
problemSizeThreshold(problemSizeThreshold),
tg(tg),
makeNewTasks(makeNewTasks),
isPostOrderPhase(false) {}
void operator()() const
tbb::task* execute() override
{
if(isPostOrderPhase)
{
// Run the post-order visitor since this task was recycled to run the post-order visitor
(void) visitorPost(treeNode, *myData);
return nullptr;
}
else
{
@ -73,10 +71,14 @@ namespace gtsam {
{
if(!treeNode->children.empty())
{
// Allocate post-order task as a continuation
isPostOrderPhase = true;
recycle_as_continuation();
bool overThreshold = (treeNode->problemSize() >= problemSizeThreshold);
// If we have child tasks, start subtasks and wait for them to complete
tbb::task_group ctg;
tbb::task* firstChild = 0;
tbb::task_list childTasks;
for(const boost::shared_ptr<NODE>& child: treeNode->children)
{
// Process child in a subtask. Important: Run visitorPre before calling
@ -84,30 +86,37 @@ namespace gtsam {
// allocated an extra child, this causes a TBB error.
boost::shared_ptr<DATA> childData = boost::allocate_shared<DATA>(
tbb::scalable_allocator<DATA>(), visitorPre(child, *myData));
ctg.run(PreOrderTask(child, childData, visitorPre, visitorPost,
problemSizeThreshold, ctg, overThreshold));
tbb::task* childTask =
new (allocate_child()) PreOrderTask(child, childData, visitorPre, visitorPost,
problemSizeThreshold, overThreshold);
if (firstChild)
childTasks.push_back(*childTask);
else
firstChild = childTask;
}
ctg.wait();
// Allocate post-order task as a continuation
isPostOrderPhase = true;
tg.run(*this);
// If we have child tasks, start subtasks and wait for them to complete
set_ref_count((int)treeNode->children.size());
spawn(childTasks);
return firstChild;
}
else
{
// Run the post-order visitor in this task if we have no children
(void) visitorPost(treeNode, *myData);
return nullptr;
}
}
else
{
// Process this node and its children in this task
processNodeRecursively(treeNode, *myData);
return nullptr;
}
}
}
void processNodeRecursively(const boost::shared_ptr<NODE>& node, DATA& myData) const
void processNodeRecursively(const boost::shared_ptr<NODE>& node, DATA& myData)
{
for(const boost::shared_ptr<NODE>& child: node->children)
{
@ -122,7 +131,7 @@ namespace gtsam {
/* ************************************************************************* */
template<typename ROOTS, typename NODE, typename DATA, typename VISITOR_PRE, typename VISITOR_POST>
class RootTask
class RootTask : public tbb::task
{
public:
const ROOTS& roots;
@ -130,31 +139,38 @@ namespace gtsam {
VISITOR_PRE& visitorPre;
VISITOR_POST& visitorPost;
int problemSizeThreshold;
tbb::task_group& tg;
RootTask(const ROOTS& roots, DATA& myData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost,
int problemSizeThreshold, tbb::task_group& tg) :
int problemSizeThreshold) :
roots(roots), myData(myData), visitorPre(visitorPre), visitorPost(visitorPost),
problemSizeThreshold(problemSizeThreshold), tg(tg) {}
problemSizeThreshold(problemSizeThreshold) {}
void operator()() const
tbb::task* execute() override
{
typedef PreOrderTask<NODE, DATA, VISITOR_PRE, VISITOR_POST> PreOrderTask;
// Create data and tasks for our children
tbb::task_list tasks;
for(const boost::shared_ptr<NODE>& root: roots)
{
boost::shared_ptr<DATA> rootData = boost::allocate_shared<DATA>(tbb::scalable_allocator<DATA>(), visitorPre(root, myData));
tg.run(PreOrderTask(root, rootData, visitorPre, visitorPost, problemSizeThreshold, tg));
tasks.push_back(*new(allocate_child())
PreOrderTask(root, rootData, visitorPre, visitorPost, problemSizeThreshold));
}
// Set TBB ref count
set_ref_count(1 + (int) roots.size());
// Spawn tasks
spawn_and_wait_for_all(tasks);
// Return nullptr
return nullptr;
}
};
template<typename NODE, typename ROOTS, typename DATA, typename VISITOR_PRE, typename VISITOR_POST>
void CreateRootTask(const ROOTS& roots, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, int problemSizeThreshold)
RootTask<ROOTS, NODE, DATA, VISITOR_PRE, VISITOR_POST>&
CreateRootTask(const ROOTS& roots, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, int problemSizeThreshold)
{
typedef RootTask<ROOTS, NODE, DATA, VISITOR_PRE, VISITOR_POST> RootTask;
tbb::task_group tg;
tg.run_and_wait(RootTask(roots, rootData, visitorPre, visitorPost, problemSizeThreshold, tg));
}
return *new(tbb::task::allocate_root()) RootTask(roots, rootData, visitorPre, visitorPost, problemSizeThreshold);
}
}