diff --git a/cartographer/common/internal/testing/thread_pool_for_testing.cc b/cartographer/common/internal/testing/thread_pool_for_testing.cc index eec5e2d..28dbbaa 100644 --- a/cartographer/common/internal/testing/thread_pool_for_testing.cc +++ b/cartographer/common/internal/testing/thread_pool_for_testing.cc @@ -21,6 +21,9 @@ #include #include +#include "cartographer/common/make_unique.h" +#include "cartographer/common/task.h" +#include "cartographer/common/time.h" #include "glog/logging.h" namespace cartographer { @@ -35,16 +38,40 @@ ThreadPoolForTesting::~ThreadPoolForTesting() { MutexLocker locker(&mutex_); CHECK(running_); running_ = false; - CHECK_EQ(work_queue_.size(), 0); + CHECK_EQ(task_queue_.size(), 0); + CHECK_EQ(tasks_not_ready_.size(), 0); } thread_.join(); } -void ThreadPoolForTesting::Schedule(const std::function &work_item) { +void ThreadPoolForTesting::NotifyDependenciesCompleted(Task* task) { MutexLocker locker(&mutex_); - idle_ = false; CHECK(running_); - work_queue_.push_back(work_item); + auto it = tasks_not_ready_.find(task); + CHECK(it != tasks_not_ready_.end()); + task_queue_.push_back(it->second); + tasks_not_ready_.erase(it); +} + +void ThreadPoolForTesting::Schedule(const std::function& work_item) { + auto task = common::make_unique(); + task->SetWorkItem(work_item); + Schedule(std::move(task)); +} + +std::weak_ptr ThreadPoolForTesting::Schedule(std::unique_ptr task) { + std::shared_ptr shared_task; + { + MutexLocker locker(&mutex_); + idle_ = false; + CHECK(running_); + auto insert_result = + tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task))); + CHECK(insert_result.second) << "ScheduleWhenReady called twice"; + shared_task = insert_result.first->second; + } + SetThreadPool(shared_task.get()); + return shared_task; } void ThreadPoolForTesting::WaitUntilIdle() { @@ -61,25 +88,25 @@ void ThreadPoolForTesting::WaitUntilIdle() { void ThreadPoolForTesting::DoWork() { for (;;) { - std::function work_item; + std::shared_ptr task; { MutexLocker locker(&mutex_); locker.AwaitWithTimeout( [this]() - REQUIRES(mutex_) { return !work_queue_.empty() || !running_; }, + REQUIRES(mutex_) { return !task_queue_.empty() || !running_; }, common::FromSeconds(0.1)); - if (!work_queue_.empty()) { - work_item = work_queue_.front(); - work_queue_.pop_front(); + if (!task_queue_.empty()) { + task = task_queue_.front(); + task_queue_.pop_front(); } if (!running_) { return; } - if (work_queue_.empty() && !work_item) { + if (tasks_not_ready_.empty() && task_queue_.empty() && !task) { idle_ = true; } } - if (work_item) work_item(); + if (task) Execute(task.get()); } } diff --git a/cartographer/common/internal/testing/thread_pool_for_testing.h b/cartographer/common/internal/testing/thread_pool_for_testing.h index 8a8778c..972a1ff 100644 --- a/cartographer/common/internal/testing/thread_pool_for_testing.h +++ b/cartographer/common/internal/testing/thread_pool_for_testing.h @@ -19,8 +19,8 @@ #include #include +#include #include -#include #include "cartographer/common/mutex.h" #include "cartographer/common/thread_pool.h" @@ -34,26 +34,24 @@ class ThreadPoolForTesting : public ThreadPoolInterface { ThreadPoolForTesting(); ~ThreadPoolForTesting(); - void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override { - LOG(FATAL) << "not implemented"; - } + void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override; void Schedule(const std::function& work_item) override; + std::weak_ptr Schedule(std::unique_ptr task) - EXCLUDES(mutex_) override { - LOG(FATAL) << "not implemented"; - } + EXCLUDES(mutex_) override; void WaitUntilIdle(); private: void DoWork(); - std::thread thread_ GUARDED_BY(mutex_); + Mutex mutex_; bool running_ GUARDED_BY(mutex_) = true; bool idle_ GUARDED_BY(mutex_) = true; - std::deque> work_queue_ GUARDED_BY(mutex_); - Mutex mutex_; + std::deque> task_queue_ GUARDED_BY(mutex_); + std::map> tasks_not_ready_ GUARDED_BY(mutex_); + std::thread thread_ GUARDED_BY(mutex_); }; } // namespace testing diff --git a/cartographer/common/task.h b/cartographer/common/task.h index 360989d..bd84699 100644 --- a/cartographer/common/task.h +++ b/cartographer/common/task.h @@ -35,6 +35,7 @@ class Task { using WorkItem = std::function; enum State { NEW, DISPATCHED, DEPENDENCIES_COMPLETED, RUNNING, COMPLETED }; + Task() = default; ~Task(); State GetState() EXCLUDES(mutex_); diff --git a/cartographer/common/task_test.cc b/cartographer/common/task_test.cc index 69eb54c..e7a3212 100644 --- a/cartographer/common/task_test.cc +++ b/cartographer/common/task_test.cc @@ -46,11 +46,13 @@ class FakeThreadPool : public ThreadPoolInterface { } std::weak_ptr Schedule(std::unique_ptr task) override { + std::shared_ptr shared_task; auto it = tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task))); EXPECT_TRUE(it.second); - SetThreadPool(it.first->first); - return it.first->second; + shared_task = it.first->second; + SetThreadPool(shared_task.get()); + return shared_task; } void RunNext() { diff --git a/cartographer/common/thread_pool.cc b/cartographer/common/thread_pool.cc index d36c83a..e21d784 100644 --- a/cartographer/common/thread_pool.cc +++ b/cartographer/common/thread_pool.cc @@ -21,6 +21,8 @@ #include #include +#include "cartographer/common/make_unique.h" +#include "cartographer/common/task.h" #include "glog/logging.h" namespace cartographer { @@ -44,17 +46,41 @@ ThreadPool::~ThreadPool() { MutexLocker locker(&mutex_); CHECK(running_); running_ = false; - CHECK_EQ(work_queue_.size(), 0); + CHECK_EQ(task_queue_.size(), 0); + CHECK_EQ(tasks_not_ready_.size(), 0); } for (std::thread& thread : pool_) { thread.join(); } } -void ThreadPool::Schedule(const std::function& work_item) { +void ThreadPool::NotifyDependenciesCompleted(Task* task) { MutexLocker locker(&mutex_); CHECK(running_); - work_queue_.push_back(work_item); + auto it = tasks_not_ready_.find(task); + CHECK(it != tasks_not_ready_.end()); + task_queue_.push_back(it->second); + tasks_not_ready_.erase(it); +} + +void ThreadPool::Schedule(const std::function& work_item) { + auto task = make_unique(); + task->SetWorkItem(work_item); + Schedule(std::move(task)); +} + +std::weak_ptr ThreadPool::Schedule(std::unique_ptr task) { + std::shared_ptr shared_task; + { + MutexLocker locker(&mutex_); + CHECK(running_); + auto insert_result = + tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task))); + CHECK(insert_result.second) << "Schedule called twice"; + shared_task = insert_result.first->second; + } + SetThreadPool(shared_task.get()); + return shared_task; } void ThreadPool::DoWork() { @@ -65,21 +91,22 @@ void ThreadPool::DoWork() { CHECK_NE(nice(10), -1); #endif for (;;) { - std::function work_item; + std::shared_ptr task; { MutexLocker locker(&mutex_); locker.Await([this]() REQUIRES(mutex_) { - return !work_queue_.empty() || !running_; + return !task_queue_.empty() || !running_; }); - if (!work_queue_.empty()) { - work_item = work_queue_.front(); - work_queue_.pop_front(); + if (!task_queue_.empty()) { + task = std::move(task_queue_.front()); + task_queue_.pop_front(); } else if (!running_) { return; } } - CHECK(work_item); - work_item(); + CHECK(task); + CHECK_EQ(task->GetState(), common::Task::DEPENDENCIES_COMPLETED); + Execute(task.get()); } } diff --git a/cartographer/common/thread_pool.h b/cartographer/common/thread_pool.h index 2121629..6dc1114 100644 --- a/cartographer/common/thread_pool.h +++ b/cartographer/common/thread_pool.h @@ -19,9 +19,9 @@ #include #include -#include #include #include +#include #include #include "cartographer/common/mutex.h" @@ -50,11 +50,12 @@ class ThreadPoolInterface { virtual void NotifyDependenciesCompleted(Task* task) = 0; }; -// A fixed number of threads working on a work queue of work items. Adding a -// new work item does not block, and will be executed by a background thread -// eventually. The queue must be empty before calling the destructor. The thread -// pool will then wait for the currently executing work items to finish and then -// destroy the threads. +// A fixed number of threads working on tasks. Adding a task does not block. +// Tasks may be added whether or not their dependencies are completed. +// When all dependencies of a task are completed, it is queued up for execution +// in a background thread. The queue must be empty before calling the +// destructor. The thread pool will then wait for the currently executing work +// items to finish and then destroy the threads. class ThreadPool : public ThreadPoolInterface { public: explicit ThreadPool(int num_threads); @@ -69,21 +70,19 @@ class ThreadPool : public ThreadPoolInterface { // When the returned weak pointer is expired, 'task' has certainly completed, // so dependants no longer need to add it as a dependency. std::weak_ptr Schedule(std::unique_ptr task) - EXCLUDES(mutex_) override { - LOG(FATAL) << "not implemented"; - } + EXCLUDES(mutex_) override; private: void DoWork(); - void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override { - LOG(FATAL) << "not implemented"; - } + void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override; Mutex mutex_; bool running_ GUARDED_BY(mutex_) = true; std::vector pool_ GUARDED_BY(mutex_); - std::deque> work_queue_ GUARDED_BY(mutex_); + std::deque> task_queue_ GUARDED_BY(mutex_); + std::unordered_map> tasks_not_ready_ + GUARDED_BY(mutex_); }; } // namespace common diff --git a/cartographer/common/thread_pool_test.cc b/cartographer/common/thread_pool_test.cc new file mode 100644 index 0000000..e2aeece --- /dev/null +++ b/cartographer/common/thread_pool_test.cc @@ -0,0 +1,131 @@ +/* + * Copyright 2018 The Cartographer Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cartographer/common/thread_pool.h" + +#include + +#include "cartographer/common/make_unique.h" +#include "gtest/gtest.h" + +namespace cartographer { +namespace common { +namespace { + +class Receiver { + public: + void Receive(int number) { + Mutex::Locker locker(&mutex_); + received_numbers_.push_back(number); + } + + void WaitForNumberSequence(const std::vector& expected_numbers) { + bool have_enough_numbers = false; + while (!have_enough_numbers) { + common::MutexLocker locker(&mutex_); + have_enough_numbers = locker.AwaitWithTimeout( + [this, &expected_numbers]() REQUIRES(mutex_) { + return (received_numbers_.size() >= expected_numbers.size()); + }, + common::FromSeconds(0.1)); + } + EXPECT_EQ(expected_numbers, received_numbers_); + } + + std::vector received_numbers_; + Mutex mutex_; +}; + +TEST(ThreadPoolTest, RunTask) { + ThreadPool pool(1); + Receiver receiver; + auto task = common::make_unique(); + task->SetWorkItem([&receiver]() { receiver.Receive(1); }); + pool.Schedule(std::move(task)); + receiver.WaitForNumberSequence({1}); +} + +TEST(ThreadPoolTest, RunWithDependency) { + ThreadPool pool(2); + Receiver receiver; + auto task_2 = common::make_unique(); + task_2->SetWorkItem([&receiver]() { receiver.Receive(2); }); + auto task_1 = common::make_unique(); + task_1->SetWorkItem([&receiver]() { receiver.Receive(1); }); + auto weak_task_1 = pool.Schedule(std::move(task_1)); + task_2->AddDependency(weak_task_1); + pool.Schedule(std::move(task_2)); + receiver.WaitForNumberSequence({1, 2}); +} + +TEST(ThreadPoolTest, RunWithOutOfScopeDependency) { + ThreadPool pool(2); + Receiver receiver; + auto task_2 = common::make_unique(); + task_2->SetWorkItem([&receiver]() { receiver.Receive(2); }); + { + auto task_1 = common::make_unique(); + task_1->SetWorkItem([&receiver]() { receiver.Receive(1); }); + auto weak_task_1 = pool.Schedule(std::move(task_1)); + task_2->AddDependency(weak_task_1); + } + pool.Schedule(std::move(task_2)); + receiver.WaitForNumberSequence({1, 2}); +} + +TEST(ThreadPoolTest, RunWithMultipleDependencies) { + ThreadPool pool(2); + Receiver receiver; + auto task_1 = common::make_unique(); + task_1->SetWorkItem([&receiver]() { receiver.Receive(1); }); + auto task_2a = common::make_unique(); + task_2a->SetWorkItem([&receiver]() { receiver.Receive(2); }); + auto task_2b = common::make_unique(); + task_2b->SetWorkItem([&receiver]() { receiver.Receive(2); }); + auto task_3 = common::make_unique(); + task_3->SetWorkItem([&receiver]() { receiver.Receive(3); }); + /* -> task_2a \ + * task_1 /-> task_2b --> task_3 + */ + auto weak_task_1 = pool.Schedule(std::move(task_1)); + task_2a->AddDependency(weak_task_1); + auto weak_task_2a = pool.Schedule(std::move(task_2a)); + task_3->AddDependency(weak_task_1); + task_3->AddDependency(weak_task_2a); + task_2b->AddDependency(weak_task_1); + auto weak_task_2b = pool.Schedule(std::move(task_2b)); + task_3->AddDependency(weak_task_2b); + pool.Schedule(std::move(task_3)); + receiver.WaitForNumberSequence({1, 2, 2, 3}); +} + +TEST(ThreadPoolTest, RunWithFinishedDependency) { + ThreadPool pool(2); + Receiver receiver; + auto task_1 = common::make_unique(); + task_1->SetWorkItem([&receiver]() { receiver.Receive(1); }); + auto task_2 = common::make_unique(); + task_2->SetWorkItem([&receiver]() { receiver.Receive(2); }); + auto weak_task_1 = pool.Schedule(std::move(task_1)); + task_2->AddDependency(weak_task_1); + receiver.WaitForNumberSequence({1}); + pool.Schedule(std::move(task_2)); + receiver.WaitForNumberSequence({1, 2}); +} + +} // namespace +} // namespace common +} // namespace cartographer