diff --git a/cartographer_grpc/framework/proto/math_service.proto b/cartographer_grpc/framework/proto/math_service.proto index 64bb3a8..95e8d03 100644 --- a/cartographer_grpc/framework/proto/math_service.proto +++ b/cartographer_grpc/framework/proto/math_service.proto @@ -40,10 +40,19 @@ message GetEchoResponse { int32 output = 1; } +message GetSequenceRequest { + int32 input = 1; +} + +message GetSequenceResponse { + int32 output = 1; +} + // Provides information about the gRPC server. service Math { rpc GetSum(stream GetSumRequest) returns (GetSumResponse); rpc GetSquare(GetSquareRequest) returns (GetSquareResponse); rpc GetRunningSum(stream GetSumRequest) returns (stream GetSumResponse); rpc GetEcho(GetEchoRequest) returns (GetEchoResponse); + rpc GetSequence(GetSequenceRequest) returns (stream GetSequenceResponse); } diff --git a/cartographer_grpc/framework/rpc.cc b/cartographer_grpc/framework/rpc.cc index e539966..860ea31 100644 --- a/cartographer_grpc/framework/rpc.cc +++ b/cartographer_grpc/framework/rpc.cc @@ -103,8 +103,13 @@ void Rpc::RequestNextMethodInvocation() { server_completion_queue_, new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true}); break; - default: - LOG(FATAL) << "RPC type not implemented."; + case ::grpc::internal::RpcMethod::SERVER_STREAMING: + service_->RequestAsyncServerStreaming( + method_index_, &server_context_, request_.get(), + streaming_interface(), server_completion_queue_, + server_completion_queue_, + new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true}); + break; } } @@ -119,14 +124,13 @@ void Rpc::RequestStreamingReadIfNeeded() { new RpcEvent{Event::READ, weak_ptr_factory_(this), true}); break; case ::grpc::internal::RpcMethod::NORMAL_RPC: - // For NORMAL_RPC we don't have to do anything here, since gRPC - // automatically issues a READ request and places the request into the - // 'Message' we provided to 'RequestAsyncUnary' above. + case ::grpc::internal::RpcMethod::SERVER_STREAMING: + // For NORMAL_RPC and SERVER_STREAMING we don't need to queue an event, + // since gRPC automatically issues a READ request and places the request + // into the 'Message' we provided to 'RequestAsyncUnary' above. OnRequest(); OnReadsDone(); break; - default: - LOG(FATAL) << "RPC type not implemented."; } } @@ -172,8 +176,8 @@ void Rpc::HandleSendQueue() { return server_async_reader_.get(); case ::grpc::internal::RpcMethod::NORMAL_RPC: return server_async_response_writer_.get(); - default: - LOG(FATAL) << "RPC type not implemented."; + case ::grpc::internal::RpcMethod::SERVER_STREAMING: + return server_async_writer_.get(); } LOG(FATAL) << "Never reached."; } @@ -187,8 +191,9 @@ Rpc::async_reader_interface() { return server_async_reader_.get(); case ::grpc::internal::RpcMethod::NORMAL_RPC: LOG(FATAL) << "For NORMAL_RPC no streaming reader interface exists."; - default: - LOG(FATAL) << "RPC type not implemented."; + case ::grpc::internal::RpcMethod::SERVER_STREAMING: + LOG(FATAL) + << "For SERVER_STREAMING no streaming reader interface exists."; } LOG(FATAL) << "Never reached."; } @@ -203,8 +208,8 @@ Rpc::async_writer_interface() { LOG(FATAL) << "For NORMAL_RPC and CLIENT_STREAMING no streaming writer " "interface exists."; break; - default: - LOG(FATAL) << "RPC type not implemented."; + case ::grpc::internal::RpcMethod::SERVER_STREAMING: + return server_async_writer_.get(); } LOG(FATAL) << "Never reached."; } @@ -253,8 +258,11 @@ void Rpc::PerformFinish(std::unique_ptr<::google::protobuf::Message> message, server_async_response_writer_.get(), status, response_.get(), new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true}); break; - default: - LOG(FATAL) << "RPC type not implemented."; + case ::grpc::internal::RpcMethod::SERVER_STREAMING: + CHECK(!message); + server_async_writer_->Finish( + status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true}); + break; } } @@ -307,8 +315,11 @@ void Rpc::InitializeReadersAndWriters( ::grpc::ServerAsyncResponseWriter>( &server_context_); break; - default: - LOG(FATAL) << "RPC type not implemented."; + case ::grpc::internal::RpcMethod::SERVER_STREAMING: + server_async_writer_ = cartographer::common::make_unique< + ::grpc::ServerAsyncWriter>( + &server_context_); + break; } } diff --git a/cartographer_grpc/framework/rpc.h b/cartographer_grpc/framework/rpc.h index 95b66f3..8f8eb74 100644 --- a/cartographer_grpc/framework/rpc.h +++ b/cartographer_grpc/framework/rpc.h @@ -132,6 +132,8 @@ class Rpc { std::unique_ptr<::grpc::ServerAsyncReaderWriter> server_async_reader_writer_; + std::unique_ptr<::grpc::ServerAsyncWriter> + server_async_writer_; cartographer::common::Mutex send_queue_lock_; std::queue send_queue_; diff --git a/cartographer_grpc/framework/server_test.cc b/cartographer_grpc/framework/server_test.cc index 4ec1f71..d0c7d42 100644 --- a/cartographer_grpc/framework/server_test.cc +++ b/cartographer_grpc/framework/server_test.cc @@ -101,6 +101,21 @@ class GetEchoHandler } }; +class GetSequenceHandler + : public RpcHandler> { + public: + void OnRequest(const proto::GetSequenceRequest& request) override { + for (int i = 0; i < request.input(); ++i) { + auto response = + cartographer::common::make_unique(); + response->set_output(i); + Send(std::move(response)); + } + Finish(::grpc::Status::OK); + } +}; + // TODO(cschuet): Due to the hard-coded part these tests will become flaky when // run in parallel. It would be nice to find a way to solve that. gRPC also // allows to communicate over UNIX domain sockets. @@ -119,6 +134,8 @@ class ServerTest : public ::testing::Test { server_builder.RegisterHandler( "GetRunningSum"); server_builder.RegisterHandler("GetEcho"); + server_builder.RegisterHandler( + "GetSequence"); server_ = server_builder.Build(); client_channel_ = @@ -189,6 +206,7 @@ TEST_F(ServerTest, ProcessBidiStreamingRpcTest) { expected_responses.pop_front(); } EXPECT_TRUE(expected_responses.empty()); + EXPECT_TRUE(reader_writer->Finish().ok()); server_->Shutdown(); } @@ -219,6 +237,24 @@ TEST_F(ServerTest, WriteFromOtherThread) { server_->Shutdown(); } +TEST_F(ServerTest, ProcessServerStreamingRpcTest) { + server_->Start(); + + proto::GetSequenceRequest request; + request.set_input(12); + auto reader = stub_->GetSequence(&client_context_, request); + + proto::GetSequenceResponse response; + for (int i = 0; i < 12; ++i) { + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.output(), i); + } + EXPECT_FALSE(reader->Read(&response)); + EXPECT_TRUE(reader->Finish().ok()); + + server_->Shutdown(); +} + } // namespace } // namespace framework } // namespace cartographer_grpc