Schedule Task in ThreadPool (#1113)

This introduces scheduling of tasks with dependencies to ThreadPool.
Contrary to work items, tasks are only queued for execution after all
their dependencies have completed.
master
gaschler 2018-04-25 16:12:30 +02:00 committed by Wally B. Feed
parent 89ac5cbabf
commit 91fda93757
7 changed files with 231 additions and 46 deletions

View File

@ -21,6 +21,9 @@
#include <chrono>
#include <numeric>
#include "cartographer/common/make_unique.h"
#include "cartographer/common/task.h"
#include "cartographer/common/time.h"
#include "glog/logging.h"
namespace cartographer {
@ -35,16 +38,40 @@ ThreadPoolForTesting::~ThreadPoolForTesting() {
MutexLocker locker(&mutex_);
CHECK(running_);
running_ = false;
CHECK_EQ(work_queue_.size(), 0);
CHECK_EQ(task_queue_.size(), 0);
CHECK_EQ(tasks_not_ready_.size(), 0);
}
thread_.join();
}
void ThreadPoolForTesting::Schedule(const std::function<void()> &work_item) {
void ThreadPoolForTesting::NotifyDependenciesCompleted(Task* task) {
MutexLocker locker(&mutex_);
idle_ = false;
CHECK(running_);
work_queue_.push_back(work_item);
auto it = tasks_not_ready_.find(task);
CHECK(it != tasks_not_ready_.end());
task_queue_.push_back(it->second);
tasks_not_ready_.erase(it);
}
void ThreadPoolForTesting::Schedule(const std::function<void()>& work_item) {
auto task = common::make_unique<Task>();
task->SetWorkItem(work_item);
Schedule(std::move(task));
}
std::weak_ptr<Task> ThreadPoolForTesting::Schedule(std::unique_ptr<Task> task) {
std::shared_ptr<Task> shared_task;
{
MutexLocker locker(&mutex_);
idle_ = false;
CHECK(running_);
auto insert_result =
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
CHECK(insert_result.second) << "ScheduleWhenReady called twice";
shared_task = insert_result.first->second;
}
SetThreadPool(shared_task.get());
return shared_task;
}
void ThreadPoolForTesting::WaitUntilIdle() {
@ -61,25 +88,25 @@ void ThreadPoolForTesting::WaitUntilIdle() {
void ThreadPoolForTesting::DoWork() {
for (;;) {
std::function<void()> work_item;
std::shared_ptr<Task> task;
{
MutexLocker locker(&mutex_);
locker.AwaitWithTimeout(
[this]()
REQUIRES(mutex_) { return !work_queue_.empty() || !running_; },
REQUIRES(mutex_) { return !task_queue_.empty() || !running_; },
common::FromSeconds(0.1));
if (!work_queue_.empty()) {
work_item = work_queue_.front();
work_queue_.pop_front();
if (!task_queue_.empty()) {
task = task_queue_.front();
task_queue_.pop_front();
}
if (!running_) {
return;
}
if (work_queue_.empty() && !work_item) {
if (tasks_not_ready_.empty() && task_queue_.empty() && !task) {
idle_ = true;
}
}
if (work_item) work_item();
if (task) Execute(task.get());
}
}

View File

@ -19,8 +19,8 @@
#include <deque>
#include <functional>
#include <map>
#include <thread>
#include <vector>
#include "cartographer/common/mutex.h"
#include "cartographer/common/thread_pool.h"
@ -34,26 +34,24 @@ class ThreadPoolForTesting : public ThreadPoolInterface {
ThreadPoolForTesting();
~ThreadPoolForTesting();
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override {
LOG(FATAL) << "not implemented";
}
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override;
void Schedule(const std::function<void()>& work_item) override;
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task)
EXCLUDES(mutex_) override {
LOG(FATAL) << "not implemented";
}
EXCLUDES(mutex_) override;
void WaitUntilIdle();
private:
void DoWork();
std::thread thread_ GUARDED_BY(mutex_);
Mutex mutex_;
bool running_ GUARDED_BY(mutex_) = true;
bool idle_ GUARDED_BY(mutex_) = true;
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mutex_);
Mutex mutex_;
std::deque<std::shared_ptr<Task>> task_queue_ GUARDED_BY(mutex_);
std::map<Task*, std::shared_ptr<Task>> tasks_not_ready_ GUARDED_BY(mutex_);
std::thread thread_ GUARDED_BY(mutex_);
};
} // namespace testing

View File

@ -35,6 +35,7 @@ class Task {
using WorkItem = std::function<void()>;
enum State { NEW, DISPATCHED, DEPENDENCIES_COMPLETED, RUNNING, COMPLETED };
Task() = default;
~Task();
State GetState() EXCLUDES(mutex_);

View File

@ -46,11 +46,13 @@ class FakeThreadPool : public ThreadPoolInterface {
}
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task) override {
std::shared_ptr<Task> shared_task;
auto it =
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
EXPECT_TRUE(it.second);
SetThreadPool(it.first->first);
return it.first->second;
shared_task = it.first->second;
SetThreadPool(shared_task.get());
return shared_task;
}
void RunNext() {

View File

@ -21,6 +21,8 @@
#include <chrono>
#include <numeric>
#include "cartographer/common/make_unique.h"
#include "cartographer/common/task.h"
#include "glog/logging.h"
namespace cartographer {
@ -44,17 +46,41 @@ ThreadPool::~ThreadPool() {
MutexLocker locker(&mutex_);
CHECK(running_);
running_ = false;
CHECK_EQ(work_queue_.size(), 0);
CHECK_EQ(task_queue_.size(), 0);
CHECK_EQ(tasks_not_ready_.size(), 0);
}
for (std::thread& thread : pool_) {
thread.join();
}
}
void ThreadPool::Schedule(const std::function<void()>& work_item) {
void ThreadPool::NotifyDependenciesCompleted(Task* task) {
MutexLocker locker(&mutex_);
CHECK(running_);
work_queue_.push_back(work_item);
auto it = tasks_not_ready_.find(task);
CHECK(it != tasks_not_ready_.end());
task_queue_.push_back(it->second);
tasks_not_ready_.erase(it);
}
void ThreadPool::Schedule(const std::function<void()>& work_item) {
auto task = make_unique<Task>();
task->SetWorkItem(work_item);
Schedule(std::move(task));
}
std::weak_ptr<Task> ThreadPool::Schedule(std::unique_ptr<Task> task) {
std::shared_ptr<Task> shared_task;
{
MutexLocker locker(&mutex_);
CHECK(running_);
auto insert_result =
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
CHECK(insert_result.second) << "Schedule called twice";
shared_task = insert_result.first->second;
}
SetThreadPool(shared_task.get());
return shared_task;
}
void ThreadPool::DoWork() {
@ -65,21 +91,22 @@ void ThreadPool::DoWork() {
CHECK_NE(nice(10), -1);
#endif
for (;;) {
std::function<void()> work_item;
std::shared_ptr<Task> task;
{
MutexLocker locker(&mutex_);
locker.Await([this]() REQUIRES(mutex_) {
return !work_queue_.empty() || !running_;
return !task_queue_.empty() || !running_;
});
if (!work_queue_.empty()) {
work_item = work_queue_.front();
work_queue_.pop_front();
if (!task_queue_.empty()) {
task = std::move(task_queue_.front());
task_queue_.pop_front();
} else if (!running_) {
return;
}
}
CHECK(work_item);
work_item();
CHECK(task);
CHECK_EQ(task->GetState(), common::Task::DEPENDENCIES_COMPLETED);
Execute(task.get());
}
}

View File

@ -19,9 +19,9 @@
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <thread>
#include <unordered_map>
#include <vector>
#include "cartographer/common/mutex.h"
@ -50,11 +50,12 @@ class ThreadPoolInterface {
virtual void NotifyDependenciesCompleted(Task* task) = 0;
};
// A fixed number of threads working on a work queue of work items. Adding a
// new work item does not block, and will be executed by a background thread
// eventually. The queue must be empty before calling the destructor. The thread
// pool will then wait for the currently executing work items to finish and then
// destroy the threads.
// A fixed number of threads working on tasks. Adding a task does not block.
// Tasks may be added whether or not their dependencies are completed.
// When all dependencies of a task are completed, it is queued up for execution
// in a background thread. The queue must be empty before calling the
// destructor. The thread pool will then wait for the currently executing work
// items to finish and then destroy the threads.
class ThreadPool : public ThreadPoolInterface {
public:
explicit ThreadPool(int num_threads);
@ -69,21 +70,19 @@ class ThreadPool : public ThreadPoolInterface {
// When the returned weak pointer is expired, 'task' has certainly completed,
// so dependants no longer need to add it as a dependency.
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task)
EXCLUDES(mutex_) override {
LOG(FATAL) << "not implemented";
}
EXCLUDES(mutex_) override;
private:
void DoWork();
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override {
LOG(FATAL) << "not implemented";
}
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override;
Mutex mutex_;
bool running_ GUARDED_BY(mutex_) = true;
std::vector<std::thread> pool_ GUARDED_BY(mutex_);
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mutex_);
std::deque<std::shared_ptr<Task>> task_queue_ GUARDED_BY(mutex_);
std::unordered_map<Task*, std::shared_ptr<Task>> tasks_not_ready_
GUARDED_BY(mutex_);
};
} // namespace common

View File

@ -0,0 +1,131 @@
/*
* Copyright 2018 The Cartographer Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cartographer/common/thread_pool.h"
#include <vector>
#include "cartographer/common/make_unique.h"
#include "gtest/gtest.h"
namespace cartographer {
namespace common {
namespace {
class Receiver {
public:
void Receive(int number) {
Mutex::Locker locker(&mutex_);
received_numbers_.push_back(number);
}
void WaitForNumberSequence(const std::vector<int>& expected_numbers) {
bool have_enough_numbers = false;
while (!have_enough_numbers) {
common::MutexLocker locker(&mutex_);
have_enough_numbers = locker.AwaitWithTimeout(
[this, &expected_numbers]() REQUIRES(mutex_) {
return (received_numbers_.size() >= expected_numbers.size());
},
common::FromSeconds(0.1));
}
EXPECT_EQ(expected_numbers, received_numbers_);
}
std::vector<int> received_numbers_;
Mutex mutex_;
};
TEST(ThreadPoolTest, RunTask) {
ThreadPool pool(1);
Receiver receiver;
auto task = common::make_unique<Task>();
task->SetWorkItem([&receiver]() { receiver.Receive(1); });
pool.Schedule(std::move(task));
receiver.WaitForNumberSequence({1});
}
TEST(ThreadPoolTest, RunWithDependency) {
ThreadPool pool(2);
Receiver receiver;
auto task_2 = common::make_unique<Task>();
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2->AddDependency(weak_task_1);
pool.Schedule(std::move(task_2));
receiver.WaitForNumberSequence({1, 2});
}
TEST(ThreadPoolTest, RunWithOutOfScopeDependency) {
ThreadPool pool(2);
Receiver receiver;
auto task_2 = common::make_unique<Task>();
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
{
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2->AddDependency(weak_task_1);
}
pool.Schedule(std::move(task_2));
receiver.WaitForNumberSequence({1, 2});
}
TEST(ThreadPoolTest, RunWithMultipleDependencies) {
ThreadPool pool(2);
Receiver receiver;
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto task_2a = common::make_unique<Task>();
task_2a->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto task_2b = common::make_unique<Task>();
task_2b->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto task_3 = common::make_unique<Task>();
task_3->SetWorkItem([&receiver]() { receiver.Receive(3); });
/* -> task_2a \
* task_1 /-> task_2b --> task_3
*/
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2a->AddDependency(weak_task_1);
auto weak_task_2a = pool.Schedule(std::move(task_2a));
task_3->AddDependency(weak_task_1);
task_3->AddDependency(weak_task_2a);
task_2b->AddDependency(weak_task_1);
auto weak_task_2b = pool.Schedule(std::move(task_2b));
task_3->AddDependency(weak_task_2b);
pool.Schedule(std::move(task_3));
receiver.WaitForNumberSequence({1, 2, 2, 3});
}
TEST(ThreadPoolTest, RunWithFinishedDependency) {
ThreadPool pool(2);
Receiver receiver;
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto task_2 = common::make_unique<Task>();
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2->AddDependency(weak_task_1);
receiver.WaitForNumberSequence({1});
pool.Schedule(std::move(task_2));
receiver.WaitForNumberSequence({1, 2});
}
} // namespace
} // namespace common
} // namespace cartographer