Schedule Task in ThreadPool ()

This introduces scheduling of tasks with dependencies to ThreadPool.
Contrary to work items, tasks are only queued for execution after all
their dependencies have completed.
master
gaschler 2018-04-25 16:12:30 +02:00 committed by Wally B. Feed
parent 89ac5cbabf
commit 91fda93757
7 changed files with 231 additions and 46 deletions

View File

@ -21,6 +21,9 @@
#include <chrono>
#include <numeric>
#include "cartographer/common/make_unique.h"
#include "cartographer/common/task.h"
#include "cartographer/common/time.h"
#include "glog/logging.h"
namespace cartographer {
@ -35,16 +38,40 @@ ThreadPoolForTesting::~ThreadPoolForTesting() {
MutexLocker locker(&mutex_);
CHECK(running_);
running_ = false;
CHECK_EQ(work_queue_.size(), 0);
CHECK_EQ(task_queue_.size(), 0);
CHECK_EQ(tasks_not_ready_.size(), 0);
}
thread_.join();
}
void ThreadPoolForTesting::Schedule(const std::function<void()> &work_item) {
void ThreadPoolForTesting::NotifyDependenciesCompleted(Task* task) {
MutexLocker locker(&mutex_);
idle_ = false;
CHECK(running_);
work_queue_.push_back(work_item);
auto it = tasks_not_ready_.find(task);
CHECK(it != tasks_not_ready_.end());
task_queue_.push_back(it->second);
tasks_not_ready_.erase(it);
}
void ThreadPoolForTesting::Schedule(const std::function<void()>& work_item) {
auto task = common::make_unique<Task>();
task->SetWorkItem(work_item);
Schedule(std::move(task));
}
std::weak_ptr<Task> ThreadPoolForTesting::Schedule(std::unique_ptr<Task> task) {
std::shared_ptr<Task> shared_task;
{
MutexLocker locker(&mutex_);
idle_ = false;
CHECK(running_);
auto insert_result =
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
CHECK(insert_result.second) << "ScheduleWhenReady called twice";
shared_task = insert_result.first->second;
}
SetThreadPool(shared_task.get());
return shared_task;
}
void ThreadPoolForTesting::WaitUntilIdle() {
@ -61,25 +88,25 @@ void ThreadPoolForTesting::WaitUntilIdle() {
void ThreadPoolForTesting::DoWork() {
for (;;) {
std::function<void()> work_item;
std::shared_ptr<Task> task;
{
MutexLocker locker(&mutex_);
locker.AwaitWithTimeout(
[this]()
REQUIRES(mutex_) { return !work_queue_.empty() || !running_; },
REQUIRES(mutex_) { return !task_queue_.empty() || !running_; },
common::FromSeconds(0.1));
if (!work_queue_.empty()) {
work_item = work_queue_.front();
work_queue_.pop_front();
if (!task_queue_.empty()) {
task = task_queue_.front();
task_queue_.pop_front();
}
if (!running_) {
return;
}
if (work_queue_.empty() && !work_item) {
if (tasks_not_ready_.empty() && task_queue_.empty() && !task) {
idle_ = true;
}
}
if (work_item) work_item();
if (task) Execute(task.get());
}
}

View File

@ -19,8 +19,8 @@
#include <deque>
#include <functional>
#include <map>
#include <thread>
#include <vector>
#include "cartographer/common/mutex.h"
#include "cartographer/common/thread_pool.h"
@ -34,26 +34,24 @@ class ThreadPoolForTesting : public ThreadPoolInterface {
ThreadPoolForTesting();
~ThreadPoolForTesting();
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override {
LOG(FATAL) << "not implemented";
}
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override;
void Schedule(const std::function<void()>& work_item) override;
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task)
EXCLUDES(mutex_) override {
LOG(FATAL) << "not implemented";
}
EXCLUDES(mutex_) override;
void WaitUntilIdle();
private:
void DoWork();
std::thread thread_ GUARDED_BY(mutex_);
Mutex mutex_;
bool running_ GUARDED_BY(mutex_) = true;
bool idle_ GUARDED_BY(mutex_) = true;
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mutex_);
Mutex mutex_;
std::deque<std::shared_ptr<Task>> task_queue_ GUARDED_BY(mutex_);
std::map<Task*, std::shared_ptr<Task>> tasks_not_ready_ GUARDED_BY(mutex_);
std::thread thread_ GUARDED_BY(mutex_);
};
} // namespace testing

View File

@ -35,6 +35,7 @@ class Task {
using WorkItem = std::function<void()>;
enum State { NEW, DISPATCHED, DEPENDENCIES_COMPLETED, RUNNING, COMPLETED };
Task() = default;
~Task();
State GetState() EXCLUDES(mutex_);

View File

@ -46,11 +46,13 @@ class FakeThreadPool : public ThreadPoolInterface {
}
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task) override {
std::shared_ptr<Task> shared_task;
auto it =
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
EXPECT_TRUE(it.second);
SetThreadPool(it.first->first);
return it.first->second;
shared_task = it.first->second;
SetThreadPool(shared_task.get());
return shared_task;
}
void RunNext() {

View File

@ -21,6 +21,8 @@
#include <chrono>
#include <numeric>
#include "cartographer/common/make_unique.h"
#include "cartographer/common/task.h"
#include "glog/logging.h"
namespace cartographer {
@ -44,17 +46,41 @@ ThreadPool::~ThreadPool() {
MutexLocker locker(&mutex_);
CHECK(running_);
running_ = false;
CHECK_EQ(work_queue_.size(), 0);
CHECK_EQ(task_queue_.size(), 0);
CHECK_EQ(tasks_not_ready_.size(), 0);
}
for (std::thread& thread : pool_) {
thread.join();
}
}
void ThreadPool::Schedule(const std::function<void()>& work_item) {
void ThreadPool::NotifyDependenciesCompleted(Task* task) {
MutexLocker locker(&mutex_);
CHECK(running_);
work_queue_.push_back(work_item);
auto it = tasks_not_ready_.find(task);
CHECK(it != tasks_not_ready_.end());
task_queue_.push_back(it->second);
tasks_not_ready_.erase(it);
}
void ThreadPool::Schedule(const std::function<void()>& work_item) {
auto task = make_unique<Task>();
task->SetWorkItem(work_item);
Schedule(std::move(task));
}
std::weak_ptr<Task> ThreadPool::Schedule(std::unique_ptr<Task> task) {
std::shared_ptr<Task> shared_task;
{
MutexLocker locker(&mutex_);
CHECK(running_);
auto insert_result =
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
CHECK(insert_result.second) << "Schedule called twice";
shared_task = insert_result.first->second;
}
SetThreadPool(shared_task.get());
return shared_task;
}
void ThreadPool::DoWork() {
@ -65,21 +91,22 @@ void ThreadPool::DoWork() {
CHECK_NE(nice(10), -1);
#endif
for (;;) {
std::function<void()> work_item;
std::shared_ptr<Task> task;
{
MutexLocker locker(&mutex_);
locker.Await([this]() REQUIRES(mutex_) {
return !work_queue_.empty() || !running_;
return !task_queue_.empty() || !running_;
});
if (!work_queue_.empty()) {
work_item = work_queue_.front();
work_queue_.pop_front();
if (!task_queue_.empty()) {
task = std::move(task_queue_.front());
task_queue_.pop_front();
} else if (!running_) {
return;
}
}
CHECK(work_item);
work_item();
CHECK(task);
CHECK_EQ(task->GetState(), common::Task::DEPENDENCIES_COMPLETED);
Execute(task.get());
}
}

View File

@ -19,9 +19,9 @@
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <thread>
#include <unordered_map>
#include <vector>
#include "cartographer/common/mutex.h"
@ -50,11 +50,12 @@ class ThreadPoolInterface {
virtual void NotifyDependenciesCompleted(Task* task) = 0;
};
// A fixed number of threads working on a work queue of work items. Adding a
// new work item does not block, and will be executed by a background thread
// eventually. The queue must be empty before calling the destructor. The thread
// pool will then wait for the currently executing work items to finish and then
// destroy the threads.
// A fixed number of threads working on tasks. Adding a task does not block.
// Tasks may be added whether or not their dependencies are completed.
// When all dependencies of a task are completed, it is queued up for execution
// in a background thread. The queue must be empty before calling the
// destructor. The thread pool will then wait for the currently executing work
// items to finish and then destroy the threads.
class ThreadPool : public ThreadPoolInterface {
public:
explicit ThreadPool(int num_threads);
@ -69,21 +70,19 @@ class ThreadPool : public ThreadPoolInterface {
// When the returned weak pointer is expired, 'task' has certainly completed,
// so dependants no longer need to add it as a dependency.
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task)
EXCLUDES(mutex_) override {
LOG(FATAL) << "not implemented";
}
EXCLUDES(mutex_) override;
private:
void DoWork();
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override {
LOG(FATAL) << "not implemented";
}
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override;
Mutex mutex_;
bool running_ GUARDED_BY(mutex_) = true;
std::vector<std::thread> pool_ GUARDED_BY(mutex_);
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mutex_);
std::deque<std::shared_ptr<Task>> task_queue_ GUARDED_BY(mutex_);
std::unordered_map<Task*, std::shared_ptr<Task>> tasks_not_ready_
GUARDED_BY(mutex_);
};
} // namespace common

View File

@ -0,0 +1,131 @@
/*
* Copyright 2018 The Cartographer Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cartographer/common/thread_pool.h"
#include <vector>
#include "cartographer/common/make_unique.h"
#include "gtest/gtest.h"
namespace cartographer {
namespace common {
namespace {
class Receiver {
public:
void Receive(int number) {
Mutex::Locker locker(&mutex_);
received_numbers_.push_back(number);
}
void WaitForNumberSequence(const std::vector<int>& expected_numbers) {
bool have_enough_numbers = false;
while (!have_enough_numbers) {
common::MutexLocker locker(&mutex_);
have_enough_numbers = locker.AwaitWithTimeout(
[this, &expected_numbers]() REQUIRES(mutex_) {
return (received_numbers_.size() >= expected_numbers.size());
},
common::FromSeconds(0.1));
}
EXPECT_EQ(expected_numbers, received_numbers_);
}
std::vector<int> received_numbers_;
Mutex mutex_;
};
TEST(ThreadPoolTest, RunTask) {
ThreadPool pool(1);
Receiver receiver;
auto task = common::make_unique<Task>();
task->SetWorkItem([&receiver]() { receiver.Receive(1); });
pool.Schedule(std::move(task));
receiver.WaitForNumberSequence({1});
}
TEST(ThreadPoolTest, RunWithDependency) {
ThreadPool pool(2);
Receiver receiver;
auto task_2 = common::make_unique<Task>();
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2->AddDependency(weak_task_1);
pool.Schedule(std::move(task_2));
receiver.WaitForNumberSequence({1, 2});
}
TEST(ThreadPoolTest, RunWithOutOfScopeDependency) {
ThreadPool pool(2);
Receiver receiver;
auto task_2 = common::make_unique<Task>();
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
{
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2->AddDependency(weak_task_1);
}
pool.Schedule(std::move(task_2));
receiver.WaitForNumberSequence({1, 2});
}
TEST(ThreadPoolTest, RunWithMultipleDependencies) {
ThreadPool pool(2);
Receiver receiver;
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto task_2a = common::make_unique<Task>();
task_2a->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto task_2b = common::make_unique<Task>();
task_2b->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto task_3 = common::make_unique<Task>();
task_3->SetWorkItem([&receiver]() { receiver.Receive(3); });
/* -> task_2a \
* task_1 /-> task_2b --> task_3
*/
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2a->AddDependency(weak_task_1);
auto weak_task_2a = pool.Schedule(std::move(task_2a));
task_3->AddDependency(weak_task_1);
task_3->AddDependency(weak_task_2a);
task_2b->AddDependency(weak_task_1);
auto weak_task_2b = pool.Schedule(std::move(task_2b));
task_3->AddDependency(weak_task_2b);
pool.Schedule(std::move(task_3));
receiver.WaitForNumberSequence({1, 2, 2, 3});
}
TEST(ThreadPoolTest, RunWithFinishedDependency) {
ThreadPool pool(2);
Receiver receiver;
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto task_2 = common::make_unique<Task>();
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2->AddDependency(weak_task_1);
receiver.WaitForNumberSequence({1});
pool.Schedule(std::move(task_2));
receiver.WaitForNumberSequence({1, 2});
}
} // namespace
} // namespace common
} // namespace cartographer