From 02359a98ae2f9d36d2e50731a3a8b090d15d5397 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20Sch=C3=BCtte?= Date: Tue, 28 Nov 2017 10:50:30 +0100 Subject: [PATCH] Implement connection establishment and server startup and shutdown. (#712) [RFC=0002](https://github.com/googlecartographer/rfcs/blob/master/text/0002-cloud-based-mapping-1.md) --- CMakeLists.txt | 8 ++- .../framework/completion_queue_thread.cc | 10 ++++ .../framework/completion_queue_thread.h | 3 + .../framework/proto/math_service.proto | 30 ++++++++++ cartographer_grpc/framework/rpc.cc | 46 +++++++++++--- cartographer_grpc/framework/rpc.h | 30 +++++++--- cartographer_grpc/framework/rpc_handler.h | 8 ++- cartographer_grpc/framework/server.cc | 47 ++++++++++++++- cartographer_grpc/framework/server.h | 14 +++-- cartographer_grpc/framework/server_test.cc | 46 ++++++++++++++ cartographer_grpc/framework/service.cc | 60 ++++++++++++++++--- cartographer_grpc/framework/service.h | 9 ++- 12 files changed, 279 insertions(+), 32 deletions(-) create mode 100644 cartographer_grpc/framework/proto/math_service.proto create mode 100644 cartographer_grpc/framework/server_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index c553125..cb7b0d1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,7 @@ if (NOT ${BUILD_GRPC}) list_remove_item(ALL_PROTOS ALL_GRPC_FILES) endif() +# TODO(cschuet): Move proto compilation to separate function. set(ALL_PROTO_SRCS) set(ALL_PROTO_HDRS) foreach(ABS_FIL ${ALL_PROTOS}) @@ -110,9 +111,7 @@ if(${BUILD_GRPC}) list(APPEND ALL_GRPC_SERVICE_HDRS "${PROJECT_BINARY_DIR}/${DIR}/${FIL_WE}.grpc.pb.h") add_custom_command( - OUTPUT "${PROJECT_BINARY_DIR}/${DIR}/${FIL_WE}.pb.cc" - "${PROJECT_BINARY_DIR}/${DIR}/${FIL_WE}.pb.h" - "${PROJECT_BINARY_DIR}/${DIR}/${FIL_WE}.grpc.pb.cc" + OUTPUT "${PROJECT_BINARY_DIR}/${DIR}/${FIL_WE}.grpc.pb.cc" "${PROJECT_BINARY_DIR}/${DIR}/${FIL_WE}.grpc.pb.h" COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} ARGS --grpc_out ${PROJECT_BINARY_DIR} @@ -152,6 +151,9 @@ foreach(ABS_FIL ${ALL_TESTS}) # Replace slashes as required for CMP0037. string(REPLACE "/" "." TEST_TARGET_NAME "${DIR}/${FIL_WE}") google_test("${TEST_TARGET_NAME}" ${ABS_FIL}) + if(${BUILD_GRPC}) + target_link_libraries("${TEST_TARGET_NAME}" PUBLIC grpc++) + endif() endforeach() target_include_directories(${PROJECT_NAME} SYSTEM PUBLIC diff --git a/cartographer_grpc/framework/completion_queue_thread.cc b/cartographer_grpc/framework/completion_queue_thread.cc index 85ba092..cd9b43a 100644 --- a/cartographer_grpc/framework/completion_queue_thread.cc +++ b/cartographer_grpc/framework/completion_queue_thread.cc @@ -26,11 +26,21 @@ CompletionQueueThread::CompletionQueueThread( std::unique_ptr<::grpc::ServerCompletionQueue> completion_queue) : completion_queue_(std::move(completion_queue)) {} +::grpc::ServerCompletionQueue* CompletionQueueThread::completion_queue() { + return completion_queue_.get(); +} + void CompletionQueueThread::Start(CompletionQueueRunner runner) { CHECK(!worker_thread_); worker_thread_ = cartographer::common::make_unique( [this, runner]() { runner(this->completion_queue_.get()); }); } +void CompletionQueueThread::Shutdown() { + LOG(INFO) << "Shutting down completion queue " << completion_queue_.get(); + completion_queue_->Shutdown(); + worker_thread_->join(); +} + } // namespace framework } // namespace cartographer_grpc diff --git a/cartographer_grpc/framework/completion_queue_thread.h b/cartographer_grpc/framework/completion_queue_thread.h index a014a70..c39a6d1 100644 --- a/cartographer_grpc/framework/completion_queue_thread.h +++ b/cartographer_grpc/framework/completion_queue_thread.h @@ -32,7 +32,10 @@ class CompletionQueueThread { explicit CompletionQueueThread( std::unique_ptr<::grpc::ServerCompletionQueue> completion_queue); + ::grpc::ServerCompletionQueue* completion_queue(); + void Start(CompletionQueueRunner runner); + void Shutdown(); private: std::unique_ptr<::grpc::ServerCompletionQueue> completion_queue_; diff --git a/cartographer_grpc/framework/proto/math_service.proto b/cartographer_grpc/framework/proto/math_service.proto new file mode 100644 index 0000000..4761b6c --- /dev/null +++ b/cartographer_grpc/framework/proto/math_service.proto @@ -0,0 +1,30 @@ +// Copyright 2017 The Cartographer Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package cartographer_grpc.framework.proto; + +message Request { + int32 input = 1; +} + +message Response { + int32 output = 1; +} + +// Provides information about the gRPC server. +service Math { + rpc GetSum(stream Request) returns (Response); +} diff --git a/cartographer_grpc/framework/rpc.cc b/cartographer_grpc/framework/rpc.cc index 4fbc363..a42c0d6 100644 --- a/cartographer_grpc/framework/rpc.cc +++ b/cartographer_grpc/framework/rpc.cc @@ -16,25 +16,44 @@ #include "cartographer_grpc/framework/rpc.h" +#include "cartographer/common/make_unique.h" #include "glog/logging.h" -#include "grpc++/impl/codegen/service_type.h" namespace cartographer_grpc { namespace framework { -Rpc::Rpc(const RpcHandlerInfo& rpc_handler_info) - : rpc_handler_info_(rpc_handler_info) {} +Rpc::Rpc(int method_index, + ::grpc::ServerCompletionQueue* server_completion_queue, + const RpcHandlerInfo& rpc_handler_info, Service* service) + : method_index_(method_index), + server_completion_queue_(server_completion_queue), + rpc_handler_info_(rpc_handler_info), + new_connection_state_{State::NEW_CONNECTION, service, this}, + read_state_{State::READ, service, this}, + write_state_{State::WRITE, service, this}, + done_state_{State::DONE, service, this} { + InitializeResponders(rpc_handler_info_.rpc_type); +} + +::grpc::ServerCompletionQueue* Rpc::server_completion_queue() { + return server_completion_queue_; +} ::grpc::internal::RpcMethod::RpcType Rpc::rpc_type() const { return rpc_handler_info_.rpc_type; } -::grpc::internal::ServerAsyncStreamingInterface* Rpc::responder() { - LOG(FATAL) << "Not yet implemented"; - return nullptr; +::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() { + switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + return server_async_reader_.get(); + default: + LOG(FATAL) << "RPC type not implemented."; + } + LOG(FATAL) << "Never reached."; } -Rpc::RpcState* Rpc::GetState(State state) { +Rpc::RpcState* Rpc::GetRpcState(State state) { switch (state) { case State::NEW_CONNECTION: return &new_connection_state_; @@ -50,6 +69,19 @@ Rpc::RpcState* Rpc::GetState(State state) { ActiveRpcs::ActiveRpcs() : lock_() {} +void Rpc::InitializeResponders(::grpc::internal::RpcMethod::RpcType rpc_type) { + switch (rpc_type) { + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + server_async_reader_ = + cartographer::common::make_unique<::grpc::ServerAsyncReader< + google::protobuf::Message, google::protobuf::Message>>( + &server_context_); + break; + default: + LOG(FATAL) << "RPC type not implemented."; + } +} + ActiveRpcs::~ActiveRpcs() { cartographer::common::MutexLocker locker(&lock_); if (!rpcs_.empty()) { diff --git a/cartographer_grpc/framework/rpc.h b/cartographer_grpc/framework/rpc.h index 0314e0d..73a7d6c 100644 --- a/cartographer_grpc/framework/rpc.h +++ b/cartographer_grpc/framework/rpc.h @@ -22,7 +22,11 @@ #include "cartographer/common/mutex.h" #include "cartographer_grpc/framework/rpc_handler.h" +#include "google/protobuf/message.h" #include "grpc++/grpc++.h" +#include "grpc++/impl/codegen/async_stream.h" +#include "grpc++/impl/codegen/proto_utils.h" +#include "grpc++/impl/codegen/service_type.h" namespace cartographer_grpc { namespace framework { @@ -33,15 +37,20 @@ class Rpc { enum class State { NEW_CONNECTION = 0, READ, WRITE, DONE }; struct RpcState { const State state; + Service* service; Rpc* rpc; }; - Rpc(const RpcHandlerInfo& rpc_handler_info); + Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, + const RpcHandlerInfo& rpc_handler_info, Service* service); + int method_index() const { return method_index_; } + ::grpc::ServerCompletionQueue* server_completion_queue(); ::grpc::internal::RpcMethod::RpcType rpc_type() const; ::grpc::ServerContext* server_context() { return &server_context_; } - ::grpc::internal::ServerAsyncStreamingInterface* responder(); - RpcState* GetState(State state); + ::grpc::internal::ServerAsyncStreamingInterface* streaming_interface(); + RpcState* GetRpcState(State state); + const RpcHandlerInfo& rpc_handler_info() const { return rpc_handler_info_; } ::google::protobuf::Message* request() { return request_.get(); } ::google::protobuf::Message* response() { return response_.get(); } @@ -49,17 +58,24 @@ class Rpc { private: Rpc(const Rpc&) = delete; Rpc& operator=(const Rpc&) = delete; + void InitializeResponders(::grpc::internal::RpcMethod::RpcType rpc_type); + int method_index_; + ::grpc::ServerCompletionQueue* server_completion_queue_; RpcHandlerInfo rpc_handler_info_; ::grpc::ServerContext server_context_; - RpcState new_connection_state_ = RpcState{State::NEW_CONNECTION, this}; - RpcState read_state_ = RpcState{State::READ, this}; - RpcState write_state_ = RpcState{State::WRITE, this}; - RpcState done_state_ = RpcState{State::DONE, this}; + RpcState new_connection_state_; + RpcState read_state_; + RpcState write_state_; + RpcState done_state_; std::unique_ptr request_; std::unique_ptr response_; + + std::unique_ptr<::grpc::ServerAsyncReader> + server_async_reader_; }; // This class keeps track of all in-flight RPCs for a 'Service'. Make sure that diff --git a/cartographer_grpc/framework/rpc_handler.h b/cartographer_grpc/framework/rpc_handler.h index e5f4668..a26bd23 100644 --- a/cartographer_grpc/framework/rpc_handler.h +++ b/cartographer_grpc/framework/rpc_handler.h @@ -27,7 +27,8 @@ namespace framework { class Rpc; class RpcHandlerInterface { public: - void SetRpc(Rpc* rpc); + virtual ~RpcHandlerInterface() = default; + virtual void SetRpc(Rpc* rpc) = 0; }; using RpcHandlerFactory = @@ -47,6 +48,11 @@ class RpcHandler : public RpcHandlerInterface { using OutgoingType = Outgoing; using RequestType = StripStream; using ResponseType = StripStream; + + void SetRpc(Rpc* rpc) override { rpc_ = rpc; } + + private: + Rpc* rpc_; }; } // namespace framework diff --git a/cartographer_grpc/framework/server.cc b/cartographer_grpc/framework/server.cc index bbaa0d6..63a8dd1 100644 --- a/cartographer_grpc/framework/server.cc +++ b/cartographer_grpc/framework/server.cc @@ -16,7 +16,6 @@ #include "cartographer_grpc/framework/server.h" -#include "cartographer/common/make_unique.h" #include "glog/logging.h" namespace cartographer_grpc { @@ -61,5 +60,51 @@ void Server::AddService( server_builder_.RegisterService(&result.first->second); } +void Server::RunCompletionQueue( + ::grpc::ServerCompletionQueue* completion_queue) { + bool ok; + void* tag; + while (completion_queue->Next(&tag, &ok)) { + auto* rpc_state = static_cast(tag); + rpc_state->service->HandleEvent(rpc_state->state, rpc_state->rpc, ok); + } +} + +void Server::Start() { + // Start the gRPC server process. + server_ = server_builder_.BuildAndStart(); + + // Start serving all services on all completion queues. + for (auto& service : services_) { + service.second.StartServing(completion_queue_threads_); + } + + // Start threads to process all completion queues. + for (auto& completion_queue_threads : completion_queue_threads_) { + completion_queue_threads.Start(Server::RunCompletionQueue); + } +} + +void Server::Shutdown() { + LOG(INFO) << "Shutting down server."; + + // Tell the services to stop serving RPCs. + for (auto& service : services_) { + service.second.StopServing(); + } + + // Shut down the gRPC server waiting for RPCs to finish until the hard + // deadline; then force a shutdown. + server_->Shutdown(); + + // Shut down the server completion queues and wait for the processing threads + // to join. + for (auto& completion_queue_threads : completion_queue_threads_) { + completion_queue_threads.Shutdown(); + } + + LOG(INFO) << "Shutdown complete."; +} + } // namespace framework } // namespace cartographer_grpc diff --git a/cartographer_grpc/framework/server.h b/cartographer_grpc/framework/server.h index f49528b..2f68f27 100644 --- a/cartographer_grpc/framework/server.h +++ b/cartographer_grpc/framework/server.h @@ -63,8 +63,8 @@ class Server { rpc_handler->SetRpc(rpc); return rpc_handler; }, - RpcType::value}); + RpcType::value}); } private: @@ -75,8 +75,11 @@ class Server { }; friend class Builder; - // Starts a server and waits for its termination. - void StartAndWait(); + // Starts a server starts serving the registered services. + void Start(); + + // Shuts down the server and all of its services. + void Shutdown(); private: Server(const Options& options); @@ -87,6 +90,9 @@ class Server { const std::string& service_name, const std::map& rpc_handler_infos); + static void RunCompletionQueue( + ::grpc::ServerCompletionQueue* completion_queue); + Options options_; // gRPC objects needed to build a server. diff --git a/cartographer_grpc/framework/server_test.cc b/cartographer_grpc/framework/server_test.cc new file mode 100644 index 0000000..7180845 --- /dev/null +++ b/cartographer_grpc/framework/server_test.cc @@ -0,0 +1,46 @@ +/* + * Copyright 2017 The Cartographer Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cartographer_grpc/framework/server.h" + +#include "cartographer_grpc/framework/proto/math_service.grpc.pb.h" +#include "cartographer_grpc/framework/proto/math_service.pb.h" +#include "cartographer_grpc/framework/rpc_handler.h" +#include "glog/logging.h" +#include "grpc++/grpc++.h" +#include "gtest/gtest.h" + +namespace cartographer_grpc { +namespace framework { +namespace { + +class GetServerOptionsHandler + : public RpcHandler, proto::Response> {}; + +TEST(ServerTest, StartServerTest) { + Server::Builder server_builder; + server_builder.SetServerAddress("0.0.0.0:50051"); + server_builder.SetNumberOfThreads(1); + server_builder.RegisterHandler( + "GetSum"); + std::unique_ptr server = server_builder.Build(); + server->Start(); + server->Shutdown(); +} + +} // namespace +} // namespace framework +} // namespace cartographer_grpc diff --git a/cartographer_grpc/framework/service.cc b/cartographer_grpc/framework/service.cc index 35395ac..a0c0a26 100644 --- a/cartographer_grpc/framework/service.cc +++ b/cartographer_grpc/framework/service.cc @@ -37,27 +37,73 @@ Service::Service(const std::string& service_name, } void Service::StartServing( - const std::vector<::grpc::ServerCompletionQueue*>& completion_queues) { + std::vector& completion_queue_threads) { int i = 0; for (const auto& rpc_handler_info : rpc_handler_infos_) { - for (auto completion_queue : completion_queues) { - Rpc* rpc = active_rpcs_.Add( - cartographer::common::make_unique(rpc_handler_info.second)); - RequestNextMethodInvocation(i, rpc, completion_queue); + for (auto& completion_queue_thread : completion_queue_threads) { + Rpc* rpc = active_rpcs_.Add(cartographer::common::make_unique( + i, completion_queue_thread.completion_queue(), + rpc_handler_info.second, this)); + RequestNextMethodInvocation(i, rpc, + completion_queue_thread.completion_queue()); } ++i; } } +void Service::StopServing() { shutting_down_ = true; } + +void Service::HandleEvent(Rpc::State state, Rpc* rpc, bool ok) { + switch (state) { + case Rpc::State::NEW_CONNECTION: + HandleNewConnection(rpc, ok); + break; + case Rpc::State::READ: + break; + case Rpc::State::WRITE: + break; + case Rpc::State::DONE: + HandleDone(rpc, ok); + break; + } +} + +void Service::HandleNewConnection(Rpc* rpc, bool ok) { + if (shutting_down_) { + LOG(WARNING) << "Server shutting down. Refusing to handle new RPCs."; + active_rpcs_.Remove(rpc); + return; + } + + if (!ok) { + LOG(ERROR) << "Failed to establish connection for unknown reason."; + active_rpcs_.Remove(rpc); + } + + // TODO(cschuet): Request next read for the new connection. + + // Create new active rpc to handle next connection. + Rpc* next_rpc = active_rpcs_.Add(cartographer::common::make_unique( + rpc->method_index(), rpc->server_completion_queue(), + rpc->rpc_handler_info(), this)); + + RequestNextMethodInvocation(rpc->method_index(), next_rpc, + rpc->server_completion_queue()); +} + +void Service::HandleDone(Rpc* rpc, bool ok) { LOG(FATAL) << "Not implemented"; } + void Service::RequestNextMethodInvocation( int method_index, Rpc* rpc, ::grpc::ServerCompletionQueue* completion_queue) { + rpc->server_context()->AsyncNotifyWhenDone( + rpc->GetRpcState(Rpc::State::DONE)); switch (rpc->rpc_type()) { case ::grpc::internal::RpcMethod::CLIENT_STREAMING: RequestAsyncClientStreaming(method_index, rpc->server_context(), - rpc->responder(), completion_queue, + rpc->streaming_interface(), completion_queue, completion_queue, - rpc->GetState(Rpc::State::NEW_CONNECTION)); + rpc->GetRpcState(Rpc::State::NEW_CONNECTION)); break; default: LOG(FATAL) << "RPC type not implemented."; diff --git a/cartographer_grpc/framework/service.h b/cartographer_grpc/framework/service.h index 6876bca..533fb9b 100644 --- a/cartographer_grpc/framework/service.h +++ b/cartographer_grpc/framework/service.h @@ -32,16 +32,21 @@ class Service : public ::grpc::Service { public: Service(const std::string& service_name, const std::map& rpc_handlers); - void StartServing( - const std::vector<::grpc::ServerCompletionQueue*>& completion_queues); + void StartServing(std::vector& completion_queues); + void HandleEvent(Rpc::State state, Rpc* rpc, bool ok); + void StopServing(); private: void RequestNextMethodInvocation( int method_index, Rpc* rpc, ::grpc::ServerCompletionQueue* completion_queue); + void HandleNewConnection(Rpc* rpc, bool ok); + void HandleDone(Rpc* rpc, bool ok); + std::map rpc_handler_infos_; ActiveRpcs active_rpcs_; + bool shutting_down_ = false; }; } // namespace framework