Schedule Task in ThreadPool (#1113)

This introduces scheduling of tasks with dependencies to ThreadPool.
Contrary to work items, tasks are only queued for execution after all
their dependencies have completed.
master
gaschler 2018-04-25 16:12:30 +02:00 committed by Wally B. Feed
parent 89ac5cbabf
commit 91fda93757
7 changed files with 231 additions and 46 deletions

View File

@ -21,6 +21,9 @@
#include <chrono> #include <chrono>
#include <numeric> #include <numeric>
#include "cartographer/common/make_unique.h"
#include "cartographer/common/task.h"
#include "cartographer/common/time.h"
#include "glog/logging.h" #include "glog/logging.h"
namespace cartographer { namespace cartographer {
@ -35,16 +38,40 @@ ThreadPoolForTesting::~ThreadPoolForTesting() {
MutexLocker locker(&mutex_); MutexLocker locker(&mutex_);
CHECK(running_); CHECK(running_);
running_ = false; running_ = false;
CHECK_EQ(work_queue_.size(), 0); CHECK_EQ(task_queue_.size(), 0);
CHECK_EQ(tasks_not_ready_.size(), 0);
} }
thread_.join(); thread_.join();
} }
void ThreadPoolForTesting::Schedule(const std::function<void()> &work_item) { void ThreadPoolForTesting::NotifyDependenciesCompleted(Task* task) {
MutexLocker locker(&mutex_); MutexLocker locker(&mutex_);
idle_ = false;
CHECK(running_); CHECK(running_);
work_queue_.push_back(work_item); auto it = tasks_not_ready_.find(task);
CHECK(it != tasks_not_ready_.end());
task_queue_.push_back(it->second);
tasks_not_ready_.erase(it);
}
void ThreadPoolForTesting::Schedule(const std::function<void()>& work_item) {
auto task = common::make_unique<Task>();
task->SetWorkItem(work_item);
Schedule(std::move(task));
}
std::weak_ptr<Task> ThreadPoolForTesting::Schedule(std::unique_ptr<Task> task) {
std::shared_ptr<Task> shared_task;
{
MutexLocker locker(&mutex_);
idle_ = false;
CHECK(running_);
auto insert_result =
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
CHECK(insert_result.second) << "ScheduleWhenReady called twice";
shared_task = insert_result.first->second;
}
SetThreadPool(shared_task.get());
return shared_task;
} }
void ThreadPoolForTesting::WaitUntilIdle() { void ThreadPoolForTesting::WaitUntilIdle() {
@ -61,25 +88,25 @@ void ThreadPoolForTesting::WaitUntilIdle() {
void ThreadPoolForTesting::DoWork() { void ThreadPoolForTesting::DoWork() {
for (;;) { for (;;) {
std::function<void()> work_item; std::shared_ptr<Task> task;
{ {
MutexLocker locker(&mutex_); MutexLocker locker(&mutex_);
locker.AwaitWithTimeout( locker.AwaitWithTimeout(
[this]() [this]()
REQUIRES(mutex_) { return !work_queue_.empty() || !running_; }, REQUIRES(mutex_) { return !task_queue_.empty() || !running_; },
common::FromSeconds(0.1)); common::FromSeconds(0.1));
if (!work_queue_.empty()) { if (!task_queue_.empty()) {
work_item = work_queue_.front(); task = task_queue_.front();
work_queue_.pop_front(); task_queue_.pop_front();
} }
if (!running_) { if (!running_) {
return; return;
} }
if (work_queue_.empty() && !work_item) { if (tasks_not_ready_.empty() && task_queue_.empty() && !task) {
idle_ = true; idle_ = true;
} }
} }
if (work_item) work_item(); if (task) Execute(task.get());
} }
} }

View File

@ -19,8 +19,8 @@
#include <deque> #include <deque>
#include <functional> #include <functional>
#include <map>
#include <thread> #include <thread>
#include <vector>
#include "cartographer/common/mutex.h" #include "cartographer/common/mutex.h"
#include "cartographer/common/thread_pool.h" #include "cartographer/common/thread_pool.h"
@ -34,26 +34,24 @@ class ThreadPoolForTesting : public ThreadPoolInterface {
ThreadPoolForTesting(); ThreadPoolForTesting();
~ThreadPoolForTesting(); ~ThreadPoolForTesting();
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override { void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override;
LOG(FATAL) << "not implemented";
}
void Schedule(const std::function<void()>& work_item) override; void Schedule(const std::function<void()>& work_item) override;
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task) std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task)
EXCLUDES(mutex_) override { EXCLUDES(mutex_) override;
LOG(FATAL) << "not implemented";
}
void WaitUntilIdle(); void WaitUntilIdle();
private: private:
void DoWork(); void DoWork();
std::thread thread_ GUARDED_BY(mutex_); Mutex mutex_;
bool running_ GUARDED_BY(mutex_) = true; bool running_ GUARDED_BY(mutex_) = true;
bool idle_ GUARDED_BY(mutex_) = true; bool idle_ GUARDED_BY(mutex_) = true;
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mutex_); std::deque<std::shared_ptr<Task>> task_queue_ GUARDED_BY(mutex_);
Mutex mutex_; std::map<Task*, std::shared_ptr<Task>> tasks_not_ready_ GUARDED_BY(mutex_);
std::thread thread_ GUARDED_BY(mutex_);
}; };
} // namespace testing } // namespace testing

View File

@ -35,6 +35,7 @@ class Task {
using WorkItem = std::function<void()>; using WorkItem = std::function<void()>;
enum State { NEW, DISPATCHED, DEPENDENCIES_COMPLETED, RUNNING, COMPLETED }; enum State { NEW, DISPATCHED, DEPENDENCIES_COMPLETED, RUNNING, COMPLETED };
Task() = default;
~Task(); ~Task();
State GetState() EXCLUDES(mutex_); State GetState() EXCLUDES(mutex_);

View File

@ -46,11 +46,13 @@ class FakeThreadPool : public ThreadPoolInterface {
} }
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task) override { std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task) override {
std::shared_ptr<Task> shared_task;
auto it = auto it =
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task))); tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
EXPECT_TRUE(it.second); EXPECT_TRUE(it.second);
SetThreadPool(it.first->first); shared_task = it.first->second;
return it.first->second; SetThreadPool(shared_task.get());
return shared_task;
} }
void RunNext() { void RunNext() {

View File

@ -21,6 +21,8 @@
#include <chrono> #include <chrono>
#include <numeric> #include <numeric>
#include "cartographer/common/make_unique.h"
#include "cartographer/common/task.h"
#include "glog/logging.h" #include "glog/logging.h"
namespace cartographer { namespace cartographer {
@ -44,17 +46,41 @@ ThreadPool::~ThreadPool() {
MutexLocker locker(&mutex_); MutexLocker locker(&mutex_);
CHECK(running_); CHECK(running_);
running_ = false; running_ = false;
CHECK_EQ(work_queue_.size(), 0); CHECK_EQ(task_queue_.size(), 0);
CHECK_EQ(tasks_not_ready_.size(), 0);
} }
for (std::thread& thread : pool_) { for (std::thread& thread : pool_) {
thread.join(); thread.join();
} }
} }
void ThreadPool::Schedule(const std::function<void()>& work_item) { void ThreadPool::NotifyDependenciesCompleted(Task* task) {
MutexLocker locker(&mutex_); MutexLocker locker(&mutex_);
CHECK(running_); CHECK(running_);
work_queue_.push_back(work_item); auto it = tasks_not_ready_.find(task);
CHECK(it != tasks_not_ready_.end());
task_queue_.push_back(it->second);
tasks_not_ready_.erase(it);
}
void ThreadPool::Schedule(const std::function<void()>& work_item) {
auto task = make_unique<Task>();
task->SetWorkItem(work_item);
Schedule(std::move(task));
}
std::weak_ptr<Task> ThreadPool::Schedule(std::unique_ptr<Task> task) {
std::shared_ptr<Task> shared_task;
{
MutexLocker locker(&mutex_);
CHECK(running_);
auto insert_result =
tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task)));
CHECK(insert_result.second) << "Schedule called twice";
shared_task = insert_result.first->second;
}
SetThreadPool(shared_task.get());
return shared_task;
} }
void ThreadPool::DoWork() { void ThreadPool::DoWork() {
@ -65,21 +91,22 @@ void ThreadPool::DoWork() {
CHECK_NE(nice(10), -1); CHECK_NE(nice(10), -1);
#endif #endif
for (;;) { for (;;) {
std::function<void()> work_item; std::shared_ptr<Task> task;
{ {
MutexLocker locker(&mutex_); MutexLocker locker(&mutex_);
locker.Await([this]() REQUIRES(mutex_) { locker.Await([this]() REQUIRES(mutex_) {
return !work_queue_.empty() || !running_; return !task_queue_.empty() || !running_;
}); });
if (!work_queue_.empty()) { if (!task_queue_.empty()) {
work_item = work_queue_.front(); task = std::move(task_queue_.front());
work_queue_.pop_front(); task_queue_.pop_front();
} else if (!running_) { } else if (!running_) {
return; return;
} }
} }
CHECK(work_item); CHECK(task);
work_item(); CHECK_EQ(task->GetState(), common::Task::DEPENDENCIES_COMPLETED);
Execute(task.get());
} }
} }

View File

@ -19,9 +19,9 @@
#include <deque> #include <deque>
#include <functional> #include <functional>
#include <map>
#include <memory> #include <memory>
#include <thread> #include <thread>
#include <unordered_map>
#include <vector> #include <vector>
#include "cartographer/common/mutex.h" #include "cartographer/common/mutex.h"
@ -50,11 +50,12 @@ class ThreadPoolInterface {
virtual void NotifyDependenciesCompleted(Task* task) = 0; virtual void NotifyDependenciesCompleted(Task* task) = 0;
}; };
// A fixed number of threads working on a work queue of work items. Adding a // A fixed number of threads working on tasks. Adding a task does not block.
// new work item does not block, and will be executed by a background thread // Tasks may be added whether or not their dependencies are completed.
// eventually. The queue must be empty before calling the destructor. The thread // When all dependencies of a task are completed, it is queued up for execution
// pool will then wait for the currently executing work items to finish and then // in a background thread. The queue must be empty before calling the
// destroy the threads. // destructor. The thread pool will then wait for the currently executing work
// items to finish and then destroy the threads.
class ThreadPool : public ThreadPoolInterface { class ThreadPool : public ThreadPoolInterface {
public: public:
explicit ThreadPool(int num_threads); explicit ThreadPool(int num_threads);
@ -69,21 +70,19 @@ class ThreadPool : public ThreadPoolInterface {
// When the returned weak pointer is expired, 'task' has certainly completed, // When the returned weak pointer is expired, 'task' has certainly completed,
// so dependants no longer need to add it as a dependency. // so dependants no longer need to add it as a dependency.
std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task) std::weak_ptr<Task> Schedule(std::unique_ptr<Task> task)
EXCLUDES(mutex_) override { EXCLUDES(mutex_) override;
LOG(FATAL) << "not implemented";
}
private: private:
void DoWork(); void DoWork();
void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override { void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override;
LOG(FATAL) << "not implemented";
}
Mutex mutex_; Mutex mutex_;
bool running_ GUARDED_BY(mutex_) = true; bool running_ GUARDED_BY(mutex_) = true;
std::vector<std::thread> pool_ GUARDED_BY(mutex_); std::vector<std::thread> pool_ GUARDED_BY(mutex_);
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mutex_); std::deque<std::shared_ptr<Task>> task_queue_ GUARDED_BY(mutex_);
std::unordered_map<Task*, std::shared_ptr<Task>> tasks_not_ready_
GUARDED_BY(mutex_);
}; };
} // namespace common } // namespace common

View File

@ -0,0 +1,131 @@
/*
* Copyright 2018 The Cartographer Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cartographer/common/thread_pool.h"
#include <vector>
#include "cartographer/common/make_unique.h"
#include "gtest/gtest.h"
namespace cartographer {
namespace common {
namespace {
class Receiver {
public:
void Receive(int number) {
Mutex::Locker locker(&mutex_);
received_numbers_.push_back(number);
}
void WaitForNumberSequence(const std::vector<int>& expected_numbers) {
bool have_enough_numbers = false;
while (!have_enough_numbers) {
common::MutexLocker locker(&mutex_);
have_enough_numbers = locker.AwaitWithTimeout(
[this, &expected_numbers]() REQUIRES(mutex_) {
return (received_numbers_.size() >= expected_numbers.size());
},
common::FromSeconds(0.1));
}
EXPECT_EQ(expected_numbers, received_numbers_);
}
std::vector<int> received_numbers_;
Mutex mutex_;
};
TEST(ThreadPoolTest, RunTask) {
ThreadPool pool(1);
Receiver receiver;
auto task = common::make_unique<Task>();
task->SetWorkItem([&receiver]() { receiver.Receive(1); });
pool.Schedule(std::move(task));
receiver.WaitForNumberSequence({1});
}
TEST(ThreadPoolTest, RunWithDependency) {
ThreadPool pool(2);
Receiver receiver;
auto task_2 = common::make_unique<Task>();
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2->AddDependency(weak_task_1);
pool.Schedule(std::move(task_2));
receiver.WaitForNumberSequence({1, 2});
}
TEST(ThreadPoolTest, RunWithOutOfScopeDependency) {
ThreadPool pool(2);
Receiver receiver;
auto task_2 = common::make_unique<Task>();
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
{
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2->AddDependency(weak_task_1);
}
pool.Schedule(std::move(task_2));
receiver.WaitForNumberSequence({1, 2});
}
TEST(ThreadPoolTest, RunWithMultipleDependencies) {
ThreadPool pool(2);
Receiver receiver;
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto task_2a = common::make_unique<Task>();
task_2a->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto task_2b = common::make_unique<Task>();
task_2b->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto task_3 = common::make_unique<Task>();
task_3->SetWorkItem([&receiver]() { receiver.Receive(3); });
/* -> task_2a \
* task_1 /-> task_2b --> task_3
*/
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2a->AddDependency(weak_task_1);
auto weak_task_2a = pool.Schedule(std::move(task_2a));
task_3->AddDependency(weak_task_1);
task_3->AddDependency(weak_task_2a);
task_2b->AddDependency(weak_task_1);
auto weak_task_2b = pool.Schedule(std::move(task_2b));
task_3->AddDependency(weak_task_2b);
pool.Schedule(std::move(task_3));
receiver.WaitForNumberSequence({1, 2, 2, 3});
}
TEST(ThreadPoolTest, RunWithFinishedDependency) {
ThreadPool pool(2);
Receiver receiver;
auto task_1 = common::make_unique<Task>();
task_1->SetWorkItem([&receiver]() { receiver.Receive(1); });
auto task_2 = common::make_unique<Task>();
task_2->SetWorkItem([&receiver]() { receiver.Receive(2); });
auto weak_task_1 = pool.Schedule(std::move(task_1));
task_2->AddDependency(weak_task_1);
receiver.WaitForNumberSequence({1});
pool.Schedule(std::move(task_2));
receiver.WaitForNumberSequence({1, 2});
}
} // namespace
} // namespace common
} // namespace cartographer