diff --git a/cartographer_grpc/framework/proto/math_service.proto b/cartographer_grpc/framework/proto/math_service.proto index ae5ee37..18a182b 100644 --- a/cartographer_grpc/framework/proto/math_service.proto +++ b/cartographer_grpc/framework/proto/math_service.proto @@ -24,7 +24,16 @@ message GetSumResponse { int32 output = 1; } +message GetSquareRequest { + int32 input = 1; +} + +message GetSquareResponse { + int32 output = 1; +} + // Provides information about the gRPC server. service Math { rpc GetSum(stream GetSumRequest) returns (GetSumResponse); + rpc GetSquare(GetSquareRequest) returns (GetSquareResponse); } diff --git a/cartographer_grpc/framework/rpc.cc b/cartographer_grpc/framework/rpc.cc index b8ab2ff..b9d400b 100644 --- a/cartographer_grpc/framework/rpc.cc +++ b/cartographer_grpc/framework/rpc.cc @@ -35,6 +35,7 @@ Rpc::Rpc(int method_index, new_connection_event_{Event::NEW_CONNECTION, this, false}, read_event_{Event::READ, this, false}, write_event_{Event::WRITE, this, false}, + finish_event_{Event::FINISH, this, false}, done_event_{Event::DONE, this, false}, handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) { InitializeReadersAndWriters(rpc_handler_info_.rpc_type); @@ -69,6 +70,12 @@ void Rpc::RequestNextMethodInvocation() { server_completion_queue_, server_completion_queue_, &new_connection_event_); break; + case ::grpc::internal::RpcMethod::NORMAL_RPC: + service_->RequestAsyncUnary( + method_index_, &server_context_, request_.get(), + streaming_interface(), server_completion_queue_, + server_completion_queue_, &new_connection_event_); + break; default: LOG(FATAL) << "RPC type not implemented."; } @@ -81,6 +88,13 @@ void Rpc::RequestStreamingReadIfNeeded() { read_event_.pending = true; async_reader_interface()->Read(request_.get(), &read_event_); break; + case ::grpc::internal::RpcMethod::NORMAL_RPC: + // For NORMAL_RPC we don't have to do anything here, since gRPC + // automatically issues a READ request and places the request into the + // 'Message' we provided to 'RequestAsyncUnary' above. + OnRequest(); + OnReadsDone(); + break; default: LOG(FATAL) << "RPC type not implemented."; } @@ -88,15 +102,28 @@ void Rpc::RequestStreamingReadIfNeeded() { void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) { response_ = std::move(message); - write_event_.pending = true; - server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK, - &write_event_); + switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK, + &finish_event_); + finish_event_.pending = true; + break; + case ::grpc::internal::RpcMethod::NORMAL_RPC: + server_async_response_writer_->Finish(*response_.get(), + ::grpc::Status::OK, &finish_event_); + finish_event_.pending = true; + break; + default: + LOG(FATAL) << "RPC type not implemented."; + } } ::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() { switch (rpc_handler_info_.rpc_type) { case ::grpc::internal::RpcMethod::CLIENT_STREAMING: return server_async_reader_.get(); + case ::grpc::internal::RpcMethod::NORMAL_RPC: + return server_async_response_writer_.get(); default: LOG(FATAL) << "RPC type not implemented."; } @@ -108,6 +135,8 @@ Rpc::async_reader_interface() { switch (rpc_handler_info_.rpc_type) { case ::grpc::internal::RpcMethod::CLIENT_STREAMING: return server_async_reader_.get(); + case ::grpc::internal::RpcMethod::NORMAL_RPC: + LOG(FATAL) << "For NORMAL_RPC no streaming interface exists."; default: LOG(FATAL) << "RPC type not implemented."; } @@ -122,6 +151,8 @@ Rpc::RpcEvent* Rpc::GetRpcEvent(Event event) { return &read_event_; case Event::WRITE: return &write_event_; + case Event::FINISH: + return &finish_event_; case Event::DONE: return &done_event_; } @@ -139,6 +170,11 @@ void Rpc::InitializeReadersAndWriters( google::protobuf::Message, google::protobuf::Message>>( &server_context_); break; + case ::grpc::internal::RpcMethod::NORMAL_RPC: + server_async_response_writer_ = cartographer::common::make_unique< + ::grpc::ServerAsyncResponseWriter>( + &server_context_); + break; default: LOG(FATAL) << "RPC type not implemented."; } diff --git a/cartographer_grpc/framework/rpc.h b/cartographer_grpc/framework/rpc.h index a8a520f..1616de1 100644 --- a/cartographer_grpc/framework/rpc.h +++ b/cartographer_grpc/framework/rpc.h @@ -36,7 +36,7 @@ namespace framework { class Service; class Rpc { public: - enum class Event { NEW_CONNECTION = 0, READ, WRITE, DONE }; + enum class Event { NEW_CONNECTION = 0, READ, WRITE, FINISH, DONE }; struct RpcEvent { const Event event; Rpc* rpc; @@ -79,6 +79,7 @@ class Rpc { RpcEvent new_connection_event_; RpcEvent read_event_; RpcEvent write_event_; + RpcEvent finish_event_; RpcEvent done_event_; std::unique_ptr request_; @@ -86,6 +87,8 @@ class Rpc { std::unique_ptr handler_; + std::unique_ptr<::grpc::ServerAsyncResponseWriter> + server_async_response_writer_; std::unique_ptr<::grpc::ServerAsyncReader> server_async_reader_; diff --git a/cartographer_grpc/framework/rpc_handler_interface.h b/cartographer_grpc/framework/rpc_handler_interface.h index ab0c21b..2979529 100644 --- a/cartographer_grpc/framework/rpc_handler_interface.h +++ b/cartographer_grpc/framework/rpc_handler_interface.h @@ -32,7 +32,7 @@ class RpcHandlerInterface { virtual void SetRpc(Rpc* rpc) = 0; virtual void OnRequestInternal( const ::google::protobuf::Message* request) = 0; - virtual void OnReadsDone() = 0; + virtual void OnReadsDone(){}; }; using RpcHandlerFactory = std::function( @@ -43,6 +43,7 @@ struct RpcHandlerInfo { const google::protobuf::Descriptor* response_descriptor; const RpcHandlerFactory rpc_handler_factory; const grpc::internal::RpcMethod::RpcType rpc_type; + const std::string fully_qualified_name; }; } // namespace framework diff --git a/cartographer_grpc/framework/server.h b/cartographer_grpc/framework/server.h index a339965..98b027c 100644 --- a/cartographer_grpc/framework/server.h +++ b/cartographer_grpc/framework/server.h @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -53,6 +54,9 @@ class Server { template void RegisterHandler(const std::string& method_name) { + std::stringstream fully_qualified_name; + fully_qualified_name << "/" << ServiceType::service_full_name() << "/" + << method_name; rpc_handlers_[ServiceType::service_full_name()].emplace( method_name, RpcHandlerInfo{ @@ -66,7 +70,8 @@ class Server { return rpc_handler; }, RpcType::value}); + typename RpcHandlerType::OutgoingType>::value, + fully_qualified_name.str()}); } private: diff --git a/cartographer_grpc/framework/server_test.cc b/cartographer_grpc/framework/server_test.cc index 2c77ec2..49d48c4 100644 --- a/cartographer_grpc/framework/server_test.cc +++ b/cartographer_grpc/framework/server_test.cc @@ -33,7 +33,7 @@ class MathServerContext : public ExecutionContext { int additional_increment() { return 10; } }; -class GetServerOptionsHandler +class GetSumHandler : public RpcHandler, proto::GetSumResponse> { public: void OnRequest(const proto::GetSumRequest& request) override { @@ -51,6 +51,16 @@ class GetServerOptionsHandler int sum_ = 0; }; +class GetSquareHandler + : public RpcHandler { + void OnRequest(const proto::GetSquareRequest& request) override { + auto response = + cartographer::common::make_unique(); + response->set_output(request.input() * request.input()); + Send(std::move(response)); + } +}; + // TODO(cschuet): Due to the hard-coded part these tests will become flaky when // run in parallel. It would be nice to find a way to solve that. gRPC also // allows to communicate over UNIX domain sockets. @@ -63,12 +73,19 @@ class ServerTest : public ::testing::Test { Server::Builder server_builder; server_builder.SetServerAddress(kServerAddress); server_builder.SetNumberOfThreads(kNumThreads); - server_builder.RegisterHandler( - "GetSum"); + server_builder.RegisterHandler("GetSum"); + server_builder.RegisterHandler("GetSquare"); server_ = server_builder.Build(); + + client_channel_ = + grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials()); + stub_ = proto::Math::NewStub(client_channel_); } std::unique_ptr server_; + std::shared_ptr client_channel_; + std::unique_ptr stub_; + grpc::ClientContext client_context_; }; TEST_F(ServerTest, StartAndStopServerTest) { @@ -81,13 +98,9 @@ TEST_F(ServerTest, ProcessRpcStreamTest) { cartographer::common::make_unique()); server_->Start(); - auto channel = - grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials()); - std::unique_ptr stub(proto::Math::NewStub(channel)); - grpc::ClientContext context; proto::GetSumResponse result; std::unique_ptr > writer( - stub->GetSum(&context, &result)); + stub_->GetSum(&client_context_, &result)); for (int i = 0; i < 3; ++i) { proto::GetSumRequest request; request.set_input(i); @@ -101,6 +114,19 @@ TEST_F(ServerTest, ProcessRpcStreamTest) { server_->Shutdown(); } +TEST_F(ServerTest, ProcessUnaryRpcTest) { + server_->Start(); + + proto::GetSquareResponse result; + proto::GetSquareRequest request; + request.set_input(11); + grpc::Status status = stub_->GetSquare(&client_context_, request, &result); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(result.output(), 121); + + server_->Shutdown(); +} + } // namespace } // namespace framework } // namespace cartographer_grpc diff --git a/cartographer_grpc/framework/service.cc b/cartographer_grpc/framework/service.cc index 76e0d0c..87405c7 100644 --- a/cartographer_grpc/framework/service.cc +++ b/cartographer_grpc/framework/service.cc @@ -26,13 +26,11 @@ Service::Service(const std::string& service_name, const std::map& rpc_handler_infos) : rpc_handler_infos_(rpc_handler_infos) { for (const auto& rpc_handler_info : rpc_handler_infos_) { - std::string fully_qualified_method_name = - "/" + service_name + "/" + rpc_handler_info.first; // The 'handler' below is set to 'nullptr' indicating that we want to // handle this method asynchronously. this->AddMethod(new grpc::internal::RpcServiceMethod( - fully_qualified_method_name.c_str(), rpc_handler_info.second.rpc_type, - nullptr /* handler */)); + rpc_handler_info.second.fully_qualified_name.c_str(), + rpc_handler_info.second.rpc_type, nullptr /* handler */)); } } @@ -65,6 +63,9 @@ void Service::HandleEvent(Rpc::Event event, Rpc* rpc, bool ok) { case Rpc::Event::WRITE: HandleWrite(rpc, ok); break; + case Rpc::Event::FINISH: + HandleFinish(rpc, ok); + break; case Rpc::Event::DONE: HandleDone(rpc, ok); break; @@ -114,12 +115,21 @@ void Service::HandleWrite(Rpc* rpc, bool ok) { RemoveIfNotPending(rpc); } +void Service::HandleFinish(Rpc* rpc, bool ok) { + if (!ok) { + LOG(ERROR) << "Finish failed"; + } + + RemoveIfNotPending(rpc); +} + void Service::HandleDone(Rpc* rpc, bool ok) { RemoveIfNotPending(rpc); } void Service::RemoveIfNotPending(Rpc* rpc) { if (!rpc->GetRpcEvent(Rpc::Event::DONE)->pending && !rpc->GetRpcEvent(Rpc::Event::READ)->pending && - !rpc->GetRpcEvent(Rpc::Event::WRITE)->pending) { + !rpc->GetRpcEvent(Rpc::Event::WRITE)->pending && + !rpc->GetRpcEvent(Rpc::Event::FINISH)->pending) { active_rpcs_.Remove(rpc); } } diff --git a/cartographer_grpc/framework/service.h b/cartographer_grpc/framework/service.h index e54af0c..b3f60a2 100644 --- a/cartographer_grpc/framework/service.h +++ b/cartographer_grpc/framework/service.h @@ -45,6 +45,7 @@ class Service : public ::grpc::Service { void HandleNewConnection(Rpc* rpc, bool ok); void HandleRead(Rpc* rpc, bool ok); void HandleWrite(Rpc* rpc, bool ok); + void HandleFinish(Rpc* rpc, bool ok); void HandleDone(Rpc* rpc, bool ok); void RemoveIfNotPending(Rpc* rpc);