Schedule Task in ThreadPool (#1113)
This introduces scheduling of tasks with dependencies to ThreadPool. Contrary to work items, tasks are only queued for execution after all their dependencies have completed.master
parent
89ac5cbabf
commit
91fda93757
|
@ -21,6 +21,9 @@
|
|||
#include <chrono>
|
||||
#include <numeric>
|
||||
|
||||
#include "cartographer/common/make_unique.h"
|
||||
#include "cartographer/common/task.h"
|
||||
#include "cartographer/common/time.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
namespace cartographer {
|
||||
|
@ -35,16 +38,40 @@ ThreadPoolForTesting::~ThreadPoolForTesting() {
|
|||
MutexLocker locker(&mutex_);
|
||||
CHECK(running_);
|
||||
running_ = false;
|
||||
CHECK_EQ(work_queue_.size(), 0);
|
||||
CHECK_EQ(task_queue_.size(), 0);
|
||||
CHECK_EQ(tasks_not_ready_.size(), 0);
|
||||
}
|
||||
thread_.join();
|
||||
}
|
||||
|
||||
void ThreadPoolForTesting::Schedule(const std::function<void()> &work_item) {
|
||||
void ThreadPoolForTesting::NotifyDependenciesCompleted(Task* task) {
|
||||
MutexLocker locker(&mutex_);
|
||||
idle_ = false;
|
||||
CHECK(running_);
|
||||
work_queue_.push_back(work_item);
|
||||
auto it = tasks_not_ready_.find(task);
|
||||
CHECK(it != tasks_not_ready_.end());
|
||||
task_queue_.push_back(it->second);
|
||||
tasks_not_ready_.erase(it);
|
||||
}
|
||||
|
||||
void ThreadPoolForTesting::Schedule(const std::function<void()>& work_item) {
|
||||
auto task = common::make_unique<Task>();
|
||||
task->SetWorkItem(work_item);
|
||||
Schedule(std::move(task));
|
||||
}
|
||||
|
||||
std::weak_ptr<Task> ThreadPoolForTesting::Schedule(std::unique_ptr<Task> task) {
|
||||
std::shared_ptr<Task> shared_task;
|
||||
{
|
||||
MutexLocker locker(&mutex_);
|
||||
idle_ = false;
|
||||
CHECK(running_);
|
||||
auto insert_result =
|
||||
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
|
||||
CHECK(insert_result.second) << "ScheduleWhenReady called twice";
|
||||
shared_task = insert_result.first->second;
|
||||
}
|
||||
SetThreadPool(shared_task.get());
|
||||
return shared_task;
|
||||
}
|
||||
|
||||
void ThreadPoolForTesting::WaitUntilIdle() {
|
||||
|
@ -61,25 +88,25 @@ void ThreadPoolForTesting::WaitUntilIdle() {
|
|||
|
||||
void ThreadPoolForTesting::DoWork() {
|
||||
for (;;) {
|
||||
std::function<void()> work_item;
|
||||
std::shared_ptr<Task> task;
|
||||
{
|
||||
MutexLocker locker(&mutex_);
|
||||
locker.AwaitWithTimeout(
|
||||
[this]()
|
||||
REQUIRES(mutex_) { return !work_queue_.empty() || !running_; },
|
||||
REQUIRES(mutex_) { return !task_queue_.empty() || !running_; },
|
||||
common::FromSeconds(0.1));
|
||||
if (!work_queue_.empty()) {
|
||||
work_item = work_queue_.front();
|
||||
work_queue_.pop_front();
|
||||
if (!task_queue_.empty()) {
|
||||
task = task_queue_.front();
|
||||
task_queue_.pop_front();
|
||||
}
|
||||
if (!running_) {
|
||||
return;
|
||||
}
|
||||
if (work_queue_.empty() && !work_item) {
|
||||
if (tasks_not_ready_.empty() && task_queue_.empty() && !task) {
|
||||
idle_ = true;
|
||||
}
|
||||
}
|
||||
if (work_item) work_item();
|
||||
if (task) Execute(task.get());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "cartographer/common/mutex.h"
|
||||
#include "cartographer/common/thread_pool.h"
|
||||
|
@ -34,26 +34,24 @@ class ThreadPoolForTesting : public ThreadPoolInterface {
|
|||
ThreadPoolForTesting();
|
||||
~ThreadPoolForTesting();
|
||||
|
||||
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override {
|
||||
LOG(FATAL) << "not implemented";
|
||||
}
|
||||
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override;
|
||||
|
||||
void Schedule(const std::function<void()>& work_item) override;
|
||||
|
||||
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task)
|
||||
EXCLUDES(mutex_) override {
|
||||
LOG(FATAL) << "not implemented";
|
||||
}
|
||||
EXCLUDES(mutex_) override;
|
||||
|
||||
void WaitUntilIdle();
|
||||
|
||||
private:
|
||||
void DoWork();
|
||||
|
||||
std::thread thread_ GUARDED_BY(mutex_);
|
||||
Mutex mutex_;
|
||||
bool running_ GUARDED_BY(mutex_) = true;
|
||||
bool idle_ GUARDED_BY(mutex_) = true;
|
||||
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mutex_);
|
||||
Mutex mutex_;
|
||||
std::deque<std::shared_ptr<Task>> task_queue_ GUARDED_BY(mutex_);
|
||||
std::map<Task*, std::shared_ptr<Task>> tasks_not_ready_ GUARDED_BY(mutex_);
|
||||
std::thread thread_ GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
} // namespace testing
|
||||
|
|
|
@ -35,6 +35,7 @@ class Task {
|
|||
using WorkItem = std::function<void()>;
|
||||
enum State { NEW, DISPATCHED, DEPENDENCIES_COMPLETED, RUNNING, COMPLETED };
|
||||
|
||||
Task() = default;
|
||||
~Task();
|
||||
|
||||
State GetState() EXCLUDES(mutex_);
|
||||
|
|
|
@ -46,11 +46,13 @@ class FakeThreadPool : public ThreadPoolInterface {
|
|||
}
|
||||
|
||||
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task) override {
|
||||
std::shared_ptr<Task> shared_task;
|
||||
auto it =
|
||||
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
|
||||
EXPECT_TRUE(it.second);
|
||||
SetThreadPool(it.first->first);
|
||||
return it.first->second;
|
||||
shared_task = it.first->second;
|
||||
SetThreadPool(shared_task.get());
|
||||
return shared_task;
|
||||
}
|
||||
|
||||
void RunNext() {
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
#include <chrono>
|
||||
#include <numeric>
|
||||
|
||||
#include "cartographer/common/make_unique.h"
|
||||
#include "cartographer/common/task.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
namespace cartographer {
|
||||
|
@ -44,17 +46,41 @@ ThreadPool::~ThreadPool() {
|
|||
MutexLocker locker(&mutex_);
|
||||
CHECK(running_);
|
||||
running_ = false;
|
||||
CHECK_EQ(work_queue_.size(), 0);
|
||||
CHECK_EQ(task_queue_.size(), 0);
|
||||
CHECK_EQ(tasks_not_ready_.size(), 0);
|
||||
}
|
||||
for (std::thread& thread : pool_) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadPool::Schedule(const std::function<void()>& work_item) {
|
||||
void ThreadPool::NotifyDependenciesCompleted(Task* task) {
|
||||
MutexLocker locker(&mutex_);
|
||||
CHECK(running_);
|
||||
work_queue_.push_back(work_item);
|
||||
auto it = tasks_not_ready_.find(task);
|
||||
CHECK(it != tasks_not_ready_.end());
|
||||
task_queue_.push_back(it->second);
|
||||
tasks_not_ready_.erase(it);
|
||||
}
|
||||
|
||||
void ThreadPool::Schedule(const std::function<void()>& work_item) {
|
||||
auto task = make_unique<Task>();
|
||||
task->SetWorkItem(work_item);
|
||||
Schedule(std::move(task));
|
||||
}
|
||||
|
||||
std::weak_ptr<Task> ThreadPool::Schedule(std::unique_ptr<Task> task) {
|
||||
std::shared_ptr<Task> shared_task;
|
||||
{
|
||||
MutexLocker locker(&mutex_);
|
||||
CHECK(running_);
|
||||
auto insert_result =
|
||||
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
|
||||
CHECK(insert_result.second) << "Schedule called twice";
|
||||
shared_task = insert_result.first->second;
|
||||
}
|
||||
SetThreadPool(shared_task.get());
|
||||
return shared_task;
|
||||
}
|
||||
|
||||
void ThreadPool::DoWork() {
|
||||
|
@ -65,21 +91,22 @@ void ThreadPool::DoWork() {
|
|||
CHECK_NE(nice(10), -1);
|
||||
#endif
|
||||
for (;;) {
|
||||
std::function<void()> work_item;
|
||||
std::shared_ptr<Task> task;
|
||||
{
|
||||
MutexLocker locker(&mutex_);
|
||||
locker.Await([this]() REQUIRES(mutex_) {
|
||||
return !work_queue_.empty() || !running_;
|
||||
return !task_queue_.empty() || !running_;
|
||||
});
|
||||
if (!work_queue_.empty()) {
|
||||
work_item = work_queue_.front();
|
||||
work_queue_.pop_front();
|
||||
if (!task_queue_.empty()) {
|
||||
task = std::move(task_queue_.front());
|
||||
task_queue_.pop_front();
|
||||
} else if (!running_) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
CHECK(work_item);
|
||||
work_item();
|
||||
CHECK(task);
|
||||
CHECK_EQ(task->GetState(), common::Task::DEPENDENCIES_COMPLETED);
|
||||
Execute(task.get());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,9 +19,9 @@
|
|||
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "cartographer/common/mutex.h"
|
||||
|
@ -50,11 +50,12 @@ class ThreadPoolInterface {
|
|||
virtual void NotifyDependenciesCompleted(Task* task) = 0;
|
||||
};
|
||||
|
||||
// A fixed number of threads working on a work queue of work items. Adding a
|
||||
// new work item does not block, and will be executed by a background thread
|
||||
// eventually. The queue must be empty before calling the destructor. The thread
|
||||
// pool will then wait for the currently executing work items to finish and then
|
||||
// destroy the threads.
|
||||
// A fixed number of threads working on tasks. Adding a task does not block.
|
||||
// Tasks may be added whether or not their dependencies are completed.
|
||||
// When all dependencies of a task are completed, it is queued up for execution
|
||||
// in a background thread. The queue must be empty before calling the
|
||||
// destructor. The thread pool will then wait for the currently executing work
|
||||
// items to finish and then destroy the threads.
|
||||
class ThreadPool : public ThreadPoolInterface {
|
||||
public:
|
||||
explicit ThreadPool(int num_threads);
|
||||
|
@ -69,21 +70,19 @@ class ThreadPool : public ThreadPoolInterface {
|
|||
// When the returned weak pointer is expired, 'task' has certainly completed,
|
||||
// so dependants no longer need to add it as a dependency.
|
||||
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task)
|
||||
EXCLUDES(mutex_) override {
|
||||
LOG(FATAL) << "not implemented";
|
||||
}
|
||||
EXCLUDES(mutex_) override;
|
||||
|
||||
private:
|
||||
void DoWork();
|
||||
|
||||
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override {
|
||||
LOG(FATAL) << "not implemented";
|
||||
}
|
||||
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override;
|
||||
|
||||
Mutex mutex_;
|
||||
bool running_ GUARDED_BY(mutex_) = true;
|
||||
std::vector<std::thread> pool_ GUARDED_BY(mutex_);
|
||||
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mutex_);
|
||||
std::deque<std::shared_ptr<Task>> task_queue_ GUARDED_BY(mutex_);
|
||||
std::unordered_map<Task*, std::shared_ptr<Task>> tasks_not_ready_
|
||||
GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
/*
|
||||
* Copyright 2018 The Cartographer Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cartographer/common/thread_pool.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "cartographer/common/make_unique.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace cartographer {
|
||||
namespace common {
|
||||
namespace {
|
||||
|
||||
class Receiver {
|
||||
public:
|
||||
void Receive(int number) {
|
||||
Mutex::Locker locker(&mutex_);
|
||||
received_numbers_.push_back(number);
|
||||
}
|
||||
|
||||
void WaitForNumberSequence(const std::vector<int>& expected_numbers) {
|
||||
bool have_enough_numbers = false;
|
||||
while (!have_enough_numbers) {
|
||||
common::MutexLocker locker(&mutex_);
|
||||
have_enough_numbers = locker.AwaitWithTimeout(
|
||||
[this, &expected_numbers]() REQUIRES(mutex_) {
|
||||
return (received_numbers_.size() >= expected_numbers.size());
|
||||
},
|
||||
common::FromSeconds(0.1));
|
||||
}
|
||||
EXPECT_EQ(expected_numbers, received_numbers_);
|
||||
}
|
||||
|
||||
std::vector<int> received_numbers_;
|
||||
Mutex mutex_;
|
||||
};
|
||||
|
||||
TEST(ThreadPoolTest, RunTask) {
|
||||
ThreadPool pool(1);
|
||||
Receiver receiver;
|
||||
auto task = common::make_unique<Task>();
|
||||
task->SetWorkItem([&receiver]() { receiver.Receive(1); });
|
||||
pool.Schedule(std::move(task));
|
||||
receiver.WaitForNumberSequence({1});
|
||||
}
|
||||
|
||||
TEST(ThreadPoolTest, RunWithDependency) {
|
||||
ThreadPool pool(2);
|
||||
Receiver receiver;
|
||||
auto task_2 = common::make_unique<Task>();
|
||||
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
|
||||
auto task_1 = common::make_unique<Task>();
|
||||
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
|
||||
auto weak_task_1 = pool.Schedule(std::move(task_1));
|
||||
task_2->AddDependency(weak_task_1);
|
||||
pool.Schedule(std::move(task_2));
|
||||
receiver.WaitForNumberSequence({1, 2});
|
||||
}
|
||||
|
||||
TEST(ThreadPoolTest, RunWithOutOfScopeDependency) {
|
||||
ThreadPool pool(2);
|
||||
Receiver receiver;
|
||||
auto task_2 = common::make_unique<Task>();
|
||||
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
|
||||
{
|
||||
auto task_1 = common::make_unique<Task>();
|
||||
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
|
||||
auto weak_task_1 = pool.Schedule(std::move(task_1));
|
||||
task_2->AddDependency(weak_task_1);
|
||||
}
|
||||
pool.Schedule(std::move(task_2));
|
||||
receiver.WaitForNumberSequence({1, 2});
|
||||
}
|
||||
|
||||
TEST(ThreadPoolTest, RunWithMultipleDependencies) {
|
||||
ThreadPool pool(2);
|
||||
Receiver receiver;
|
||||
auto task_1 = common::make_unique<Task>();
|
||||
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
|
||||
auto task_2a = common::make_unique<Task>();
|
||||
task_2a->SetWorkItem([&receiver]() { receiver.Receive(2); });
|
||||
auto task_2b = common::make_unique<Task>();
|
||||
task_2b->SetWorkItem([&receiver]() { receiver.Receive(2); });
|
||||
auto task_3 = common::make_unique<Task>();
|
||||
task_3->SetWorkItem([&receiver]() { receiver.Receive(3); });
|
||||
/* -> task_2a \
|
||||
* task_1 /-> task_2b --> task_3
|
||||
*/
|
||||
auto weak_task_1 = pool.Schedule(std::move(task_1));
|
||||
task_2a->AddDependency(weak_task_1);
|
||||
auto weak_task_2a = pool.Schedule(std::move(task_2a));
|
||||
task_3->AddDependency(weak_task_1);
|
||||
task_3->AddDependency(weak_task_2a);
|
||||
task_2b->AddDependency(weak_task_1);
|
||||
auto weak_task_2b = pool.Schedule(std::move(task_2b));
|
||||
task_3->AddDependency(weak_task_2b);
|
||||
pool.Schedule(std::move(task_3));
|
||||
receiver.WaitForNumberSequence({1, 2, 2, 3});
|
||||
}
|
||||
|
||||
TEST(ThreadPoolTest, RunWithFinishedDependency) {
|
||||
ThreadPool pool(2);
|
||||
Receiver receiver;
|
||||
auto task_1 = common::make_unique<Task>();
|
||||
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
|
||||
auto task_2 = common::make_unique<Task>();
|
||||
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
|
||||
auto weak_task_1 = pool.Schedule(std::move(task_1));
|
||||
task_2->AddDependency(weak_task_1);
|
||||
receiver.WaitForNumberSequence({1});
|
||||
pool.Schedule(std::move(task_2));
|
||||
receiver.WaitForNumberSequence({1, 2});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace common
|
||||
} // namespace cartographer
|
Loading…
Reference in New Issue