From 3a46804393a2372cc54074adadf550a74ab71e77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20Sch=C3=BCtte?= Date: Wed, 29 Nov 2017 10:40:26 +0100 Subject: [PATCH] Implement end-to-end client streaming RPC. (#713) [RFC=0002](https://github.com/googlecartographer/rfcs/blob/master/text/0002-cloud-based-mapping-1.md) --- .../framework/proto/math_service.proto | 6 +- cartographer_grpc/framework/rpc.cc | 99 +++++++++++++++---- cartographer_grpc/framework/rpc.h | 52 ++++++---- cartographer_grpc/framework/rpc_handler.h | 28 +++--- .../framework/rpc_handler_interface.h | 49 +++++++++ cartographer_grpc/framework/server.cc | 5 +- cartographer_grpc/framework/server_test.cc | 71 +++++++++++-- cartographer_grpc/framework/service.cc | 77 +++++++++------ cartographer_grpc/framework/service.h | 13 ++- 9 files changed, 290 insertions(+), 110 deletions(-) create mode 100644 cartographer_grpc/framework/rpc_handler_interface.h diff --git a/cartographer_grpc/framework/proto/math_service.proto b/cartographer_grpc/framework/proto/math_service.proto index 4761b6c..ae5ee37 100644 --- a/cartographer_grpc/framework/proto/math_service.proto +++ b/cartographer_grpc/framework/proto/math_service.proto @@ -16,15 +16,15 @@ syntax = "proto3"; package cartographer_grpc.framework.proto; -message Request { +message GetSumRequest { int32 input = 1; } -message Response { +message GetSumResponse { int32 output = 1; } // Provides information about the gRPC server. service Math { - rpc GetSum(stream Request) returns (Response); + rpc GetSum(stream GetSumRequest) returns (GetSumResponse); } diff --git a/cartographer_grpc/framework/rpc.cc b/cartographer_grpc/framework/rpc.cc index a42c0d6..f616e8a 100644 --- a/cartographer_grpc/framework/rpc.cc +++ b/cartographer_grpc/framework/rpc.cc @@ -15,6 +15,7 @@ */ #include "cartographer_grpc/framework/rpc.h" +#include "cartographer_grpc/framework/service.h" #include "cartographer/common/make_unique.h" #include "glog/logging.h" @@ -28,19 +29,65 @@ Rpc::Rpc(int method_index, : method_index_(method_index), server_completion_queue_(server_completion_queue), rpc_handler_info_(rpc_handler_info), - new_connection_state_{State::NEW_CONNECTION, service, this}, - read_state_{State::READ, service, this}, - write_state_{State::WRITE, service, this}, - done_state_{State::DONE, service, this} { - InitializeResponders(rpc_handler_info_.rpc_type); + service_(service), + new_connection_event_{Event::NEW_CONNECTION, this, false}, + read_event_{Event::READ, this, false}, + write_event_{Event::WRITE, this, false}, + done_event_{Event::DONE, this, false}, + handler_(rpc_handler_info_.rpc_handler_factory(this)) { + InitializeReadersAndWriters(rpc_handler_info_.rpc_type); + + // Initialize the prototypical request and response messages. + request_.reset(::google::protobuf::MessageFactory::generated_factory() + ->GetPrototype(rpc_handler_info_.request_descriptor) + ->New()); + response_.reset(::google::protobuf::MessageFactory::generated_factory() + ->GetPrototype(rpc_handler_info_.response_descriptor) + ->New()); } -::grpc::ServerCompletionQueue* Rpc::server_completion_queue() { - return server_completion_queue_; +std::unique_ptr Rpc::Clone() { + return cartographer::common::make_unique( + method_index_, server_completion_queue_, rpc_handler_info_, service_); } -::grpc::internal::RpcMethod::RpcType Rpc::rpc_type() const { - return rpc_handler_info_.rpc_type; +void Rpc::OnRequest() { handler_->OnRequestInternal(request_.get()); } + +void Rpc::OnReadsDone() { handler_->OnReadsDone(); } + +void Rpc::RequestNextMethodInvocation() { + done_event_.pending = true; + new_connection_event_.pending = true; + server_context_.AsyncNotifyWhenDone(&done_event_); + switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + service_->RequestAsyncClientStreaming( + method_index_, &server_context_, streaming_interface(), + server_completion_queue_, server_completion_queue_, + &new_connection_event_); + break; + default: + LOG(FATAL) << "RPC type not implemented."; + } +} + +void Rpc::RequestStreamingReadIfNeeded() { + // For request-streaming RPCs ask the client to start sending requests. + switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + read_event_.pending = true; + async_reader_interface()->Read(request_.get(), &read_event_); + break; + default: + LOG(FATAL) << "RPC type not implemented."; + } +} + +void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) { + response_ = std::move(message); + write_event_.pending = true; + server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK, + &write_event_); } ::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() { @@ -53,23 +100,35 @@ Rpc::Rpc(int method_index, LOG(FATAL) << "Never reached."; } -Rpc::RpcState* Rpc::GetRpcState(State state) { - switch (state) { - case State::NEW_CONNECTION: - return &new_connection_state_; - case State::READ: - return &read_state_; - case State::WRITE: - return &write_state_; - case State::DONE: - return &done_state_; +::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>* +Rpc::async_reader_interface() { + switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + return server_async_reader_.get(); + default: + LOG(FATAL) << "RPC type not implemented."; + } + LOG(FATAL) << "Never reached."; +} + +Rpc::RpcEvent* Rpc::GetRpcEvent(Event event) { + switch (event) { + case Event::NEW_CONNECTION: + return &new_connection_event_; + case Event::READ: + return &read_event_; + case Event::WRITE: + return &write_event_; + case Event::DONE: + return &done_event_; } LOG(FATAL) << "Never reached."; } ActiveRpcs::ActiveRpcs() : lock_() {} -void Rpc::InitializeResponders(::grpc::internal::RpcMethod::RpcType rpc_type) { +void Rpc::InitializeReadersAndWriters( + ::grpc::internal::RpcMethod::RpcType rpc_type) { switch (rpc_type) { case ::grpc::internal::RpcMethod::CLIENT_STREAMING: server_async_reader_ = diff --git a/cartographer_grpc/framework/rpc.h b/cartographer_grpc/framework/rpc.h index 73a7d6c..be89118 100644 --- a/cartographer_grpc/framework/rpc.h +++ b/cartographer_grpc/framework/rpc.h @@ -21,10 +21,11 @@ #include #include "cartographer/common/mutex.h" -#include "cartographer_grpc/framework/rpc_handler.h" +#include "cartographer_grpc/framework/rpc_handler_interface.h" #include "google/protobuf/message.h" #include "grpc++/grpc++.h" #include "grpc++/impl/codegen/async_stream.h" +#include "grpc++/impl/codegen/async_unary_call.h" #include "grpc++/impl/codegen/proto_utils.h" #include "grpc++/impl/codegen/service_type.h" @@ -34,45 +35,54 @@ namespace framework { class Service; class Rpc { public: - enum class State { NEW_CONNECTION = 0, READ, WRITE, DONE }; - struct RpcState { - const State state; - Service* service; + enum class Event { NEW_CONNECTION = 0, READ, WRITE, DONE }; + struct RpcEvent { + const Event event; Rpc* rpc; + // Indicates whether the event is pending completion. E.g. 'event = READ' + // and 'pending = true' means that a read has been requested but hasn't + // completed yet. While 'pending = false' indicates, that the read has + // completed and currently no read is in-flight. + bool pending; }; Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, const RpcHandlerInfo& rpc_handler_info, Service* service); - - int method_index() const { return method_index_; } - ::grpc::ServerCompletionQueue* server_completion_queue(); - ::grpc::internal::RpcMethod::RpcType rpc_type() const; - ::grpc::ServerContext* server_context() { return &server_context_; } - ::grpc::internal::ServerAsyncStreamingInterface* streaming_interface(); - RpcState* GetRpcState(State state); - const RpcHandlerInfo& rpc_handler_info() const { return rpc_handler_info_; } - - ::google::protobuf::Message* request() { return request_.get(); } - ::google::protobuf::Message* response() { return response_.get(); } + std::unique_ptr Clone(); + void OnRequest(); + void OnReadsDone(); + void RequestNextMethodInvocation(); + void RequestStreamingReadIfNeeded(); + void Write(std::unique_ptr<::google::protobuf::Message> message); + Service* service() { return service_; } + RpcEvent* GetRpcEvent(Event event); private: Rpc(const Rpc&) = delete; Rpc& operator=(const Rpc&) = delete; - void InitializeResponders(::grpc::internal::RpcMethod::RpcType rpc_type); + void InitializeReadersAndWriters( + ::grpc::internal::RpcMethod::RpcType rpc_type); + + ::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>* + async_reader_interface(); + ::grpc::internal::ServerAsyncStreamingInterface* streaming_interface(); int method_index_; ::grpc::ServerCompletionQueue* server_completion_queue_; RpcHandlerInfo rpc_handler_info_; + Service* service_; ::grpc::ServerContext server_context_; - RpcState new_connection_state_; - RpcState read_state_; - RpcState write_state_; - RpcState done_state_; + RpcEvent new_connection_event_; + RpcEvent read_event_; + RpcEvent write_event_; + RpcEvent done_event_; std::unique_ptr request_; std::unique_ptr response_; + std::unique_ptr handler_; + std::unique_ptr<::grpc::ServerAsyncReader> server_async_reader_; diff --git a/cartographer_grpc/framework/rpc_handler.h b/cartographer_grpc/framework/rpc_handler.h index a26bd23..7680fc8 100644 --- a/cartographer_grpc/framework/rpc_handler.h +++ b/cartographer_grpc/framework/rpc_handler.h @@ -17,30 +17,16 @@ #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H #define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H +#include "cartographer_grpc/framework/rpc.h" +#include "cartographer_grpc/framework/rpc_handler_interface.h" #include "cartographer_grpc/framework/type_traits.h" +#include "glog/logging.h" #include "google/protobuf/message.h" #include "grpc++/grpc++.h" namespace cartographer_grpc { namespace framework { -class Rpc; -class RpcHandlerInterface { - public: - virtual ~RpcHandlerInterface() = default; - virtual void SetRpc(Rpc* rpc) = 0; -}; - -using RpcHandlerFactory = - std::function(Rpc*)>; - -struct RpcHandlerInfo { - const google::protobuf::Descriptor* request_descriptor; - const google::protobuf::Descriptor* response_descriptor; - const RpcHandlerFactory rpc_handler_factory; - const grpc::internal::RpcMethod::RpcType rpc_type; -}; - template class RpcHandler : public RpcHandlerInterface { public: @@ -50,6 +36,14 @@ class RpcHandler : public RpcHandlerInterface { using ResponseType = StripStream; void SetRpc(Rpc* rpc) override { rpc_ = rpc; } + void OnRequestInternal(const ::google::protobuf::Message* request) override { + DCHECK(dynamic_cast(request)); + OnRequest(static_cast(*request)); + } + virtual void OnRequest(const RequestType& request) = 0; + void Send(std::unique_ptr response) { + rpc_->Write(std::move(response)); + } private: Rpc* rpc_; diff --git a/cartographer_grpc/framework/rpc_handler_interface.h b/cartographer_grpc/framework/rpc_handler_interface.h new file mode 100644 index 0000000..fb36971 --- /dev/null +++ b/cartographer_grpc/framework/rpc_handler_interface.h @@ -0,0 +1,49 @@ +/* + * Copyright 2017 The Cartographer Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H +#define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H + +#include "google/protobuf/message.h" +#include "grpc++/grpc++.h" + +namespace cartographer_grpc { +namespace framework { + +class Rpc; +class RpcHandlerInterface { + public: + virtual ~RpcHandlerInterface() = default; + virtual void SetRpc(Rpc* rpc) = 0; + virtual void OnRequestInternal( + const ::google::protobuf::Message* request) = 0; + virtual void OnReadsDone() = 0; +}; + +using RpcHandlerFactory = + std::function(Rpc*)>; + +struct RpcHandlerInfo { + const google::protobuf::Descriptor* request_descriptor; + const google::protobuf::Descriptor* response_descriptor; + const RpcHandlerFactory rpc_handler_factory; + const grpc::internal::RpcMethod::RpcType rpc_type; +}; + +} // namespace framework +} // namespace cartographer_grpc + +#endif // CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H diff --git a/cartographer_grpc/framework/server.cc b/cartographer_grpc/framework/server.cc index 63a8dd1..507af64 100644 --- a/cartographer_grpc/framework/server.cc +++ b/cartographer_grpc/framework/server.cc @@ -65,8 +65,9 @@ void Server::RunCompletionQueue( bool ok; void* tag; while (completion_queue->Next(&tag, &ok)) { - auto* rpc_state = static_cast(tag); - rpc_state->service->HandleEvent(rpc_state->state, rpc_state->rpc, ok); + auto* rpc_event = static_cast(tag); + rpc_event->rpc->service()->HandleEvent(rpc_event->event, rpc_event->rpc, + ok); } } diff --git a/cartographer_grpc/framework/server_test.cc b/cartographer_grpc/framework/server_test.cc index 7180845..02ad2a1 100644 --- a/cartographer_grpc/framework/server_test.cc +++ b/cartographer_grpc/framework/server_test.cc @@ -28,17 +28,68 @@ namespace framework { namespace { class GetServerOptionsHandler - : public RpcHandler, proto::Response> {}; + : public RpcHandler, proto::GetSumResponse> { + public: + void OnRequest(const proto::GetSumRequest& request) override { + sum_ += request.input(); + } -TEST(ServerTest, StartServerTest) { - Server::Builder server_builder; - server_builder.SetServerAddress("0.0.0.0:50051"); - server_builder.SetNumberOfThreads(1); - server_builder.RegisterHandler( - "GetSum"); - std::unique_ptr server = server_builder.Build(); - server->Start(); - server->Shutdown(); + void OnReadsDone() override { + auto response = cartographer::common::make_unique(); + response->set_output(sum_); + Send(std::move(response)); + } + + private: + int sum_ = 0; +}; + +// TODO(cschuet): Due to the hard-coded part these tests will become flaky when +// run in parallel. It would be nice to find a way to solve that. gRPC also +// allows to communicate over UNIX domain sockets. +const std::string kServerAddress = "localhost:50051"; +const std::size_t kNumThreads = 1; + +class ServerTest : public ::testing::Test { + protected: + void SetUp() override { + Server::Builder server_builder; + server_builder.SetServerAddress(kServerAddress); + server_builder.SetNumberOfThreads(kNumThreads); + server_builder.RegisterHandler( + "GetSum"); + server_ = server_builder.Build(); + } + + std::unique_ptr server_; +}; + +TEST_F(ServerTest, StartAndStopServerTest) { + server_->Start(); + server_->Shutdown(); +} + +TEST_F(ServerTest, ProcessRpcStreamTest) { + server_->Start(); + + auto channel = + grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials()); + std::unique_ptr stub(proto::Math::NewStub(channel)); + grpc::ClientContext context; + proto::GetSumResponse result; + std::unique_ptr > writer( + stub->GetSum(&context, &result)); + for (int i = 0; i < 3; ++i) { + proto::GetSumRequest request; + request.set_input(i); + EXPECT_TRUE(writer->Write(request)); + } + writer->WritesDone(); + grpc::Status status = writer->Finish(); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(result.output(), 3); + + server_->Shutdown(); } } // namespace diff --git a/cartographer_grpc/framework/service.cc b/cartographer_grpc/framework/service.cc index a0c0a26..543c01e 100644 --- a/cartographer_grpc/framework/service.cc +++ b/cartographer_grpc/framework/service.cc @@ -44,8 +44,7 @@ void Service::StartServing( Rpc* rpc = active_rpcs_.Add(cartographer::common::make_unique( i, completion_queue_thread.completion_queue(), rpc_handler_info.second, this)); - RequestNextMethodInvocation(i, rpc, - completion_queue_thread.completion_queue()); + rpc->RequestNextMethodInvocation(); } ++i; } @@ -53,16 +52,19 @@ void Service::StartServing( void Service::StopServing() { shutting_down_ = true; } -void Service::HandleEvent(Rpc::State state, Rpc* rpc, bool ok) { - switch (state) { - case Rpc::State::NEW_CONNECTION: +void Service::HandleEvent(Rpc::Event event, Rpc* rpc, bool ok) { + rpc->GetRpcEvent(event)->pending = false; + switch (event) { + case Rpc::Event::NEW_CONNECTION: HandleNewConnection(rpc, ok); break; - case Rpc::State::READ: + case Rpc::Event::READ: + HandleRead(rpc, ok); break; - case Rpc::State::WRITE: + case Rpc::Event::WRITE: + HandleWrite(rpc, ok); break; - case Rpc::State::DONE: + case Rpc::Event::DONE: HandleDone(rpc, ok); break; } @@ -70,7 +72,9 @@ void Service::HandleEvent(Rpc::State state, Rpc* rpc, bool ok) { void Service::HandleNewConnection(Rpc* rpc, bool ok) { if (shutting_down_) { - LOG(WARNING) << "Server shutting down. Refusing to handle new RPCs."; + if (ok) { + LOG(WARNING) << "Server shutting down. Refusing to handle new RPCs."; + } active_rpcs_.Remove(rpc); return; } @@ -80,33 +84,42 @@ void Service::HandleNewConnection(Rpc* rpc, bool ok) { active_rpcs_.Remove(rpc); } - // TODO(cschuet): Request next read for the new connection. + if (ok) { + // For request-streaming RPCs ask the client to start sending requests. + rpc->RequestStreamingReadIfNeeded(); + } - // Create new active rpc to handle next connection. - Rpc* next_rpc = active_rpcs_.Add(cartographer::common::make_unique( - rpc->method_index(), rpc->server_completion_queue(), - rpc->rpc_handler_info(), this)); - - RequestNextMethodInvocation(rpc->method_index(), next_rpc, - rpc->server_completion_queue()); + // Create new active rpc to handle next connection and register it for the + // incoming connection. + active_rpcs_.Add(rpc->Clone())->RequestNextMethodInvocation(); } -void Service::HandleDone(Rpc* rpc, bool ok) { LOG(FATAL) << "Not implemented"; } +void Service::HandleRead(Rpc* rpc, bool ok) { + if (ok) { + rpc->OnRequest(); + rpc->RequestStreamingReadIfNeeded(); + return; + } -void Service::RequestNextMethodInvocation( - int method_index, Rpc* rpc, - ::grpc::ServerCompletionQueue* completion_queue) { - rpc->server_context()->AsyncNotifyWhenDone( - rpc->GetRpcState(Rpc::State::DONE)); - switch (rpc->rpc_type()) { - case ::grpc::internal::RpcMethod::CLIENT_STREAMING: - RequestAsyncClientStreaming(method_index, rpc->server_context(), - rpc->streaming_interface(), completion_queue, - completion_queue, - rpc->GetRpcState(Rpc::State::NEW_CONNECTION)); - break; - default: - LOG(FATAL) << "RPC type not implemented."; + // Reads completed. + rpc->OnReadsDone(); +} + +void Service::HandleWrite(Rpc* rpc, bool ok) { + if (!ok) { + LOG(ERROR) << "Write failed"; + } + + RemoveIfNotPending(rpc); +} + +void Service::HandleDone(Rpc* rpc, bool ok) { RemoveIfNotPending(rpc); } + +void Service::RemoveIfNotPending(Rpc* rpc) { + if (!rpc->GetRpcEvent(Rpc::Event::DONE)->pending && + !rpc->GetRpcEvent(Rpc::Event::READ)->pending && + !rpc->GetRpcEvent(Rpc::Event::WRITE)->pending) { + active_rpcs_.Remove(rpc); } } diff --git a/cartographer_grpc/framework/service.h b/cartographer_grpc/framework/service.h index 533fb9b..00c9909 100644 --- a/cartographer_grpc/framework/service.h +++ b/cartographer_grpc/framework/service.h @@ -17,6 +17,7 @@ #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_SERVICE_H #define CARTOGRAPHER_GRPC_FRAMEWORK_SERVICE_H +#include "cartographer_grpc/framework/completion_queue_thread.h" #include "cartographer_grpc/framework/rpc.h" #include "cartographer_grpc/framework/rpc_handler.h" #include "grpc++/impl/codegen/service_type.h" @@ -30,20 +31,22 @@ namespace framework { // 'Rpc' handler objects. class Service : public ::grpc::Service { public: + friend class Rpc; + Service(const std::string& service_name, const std::map& rpc_handlers); void StartServing(std::vector& completion_queues); - void HandleEvent(Rpc::State state, Rpc* rpc, bool ok); + void HandleEvent(Rpc::Event event, Rpc* rpc, bool ok); void StopServing(); private: - void RequestNextMethodInvocation( - int method_index, Rpc* rpc, - ::grpc::ServerCompletionQueue* completion_queue); - void HandleNewConnection(Rpc* rpc, bool ok); + void HandleRead(Rpc* rpc, bool ok); + void HandleWrite(Rpc* rpc, bool ok); void HandleDone(Rpc* rpc, bool ok); + void RemoveIfNotPending(Rpc* rpc); + std::map rpc_handler_infos_; ActiveRpcs active_rpcs_; bool shutting_down_ = false;