diff --git a/cartographer_grpc/framework/proto/math_service.proto b/cartographer_grpc/framework/proto/math_service.proto index 18a182b..474e1ea 100644 --- a/cartographer_grpc/framework/proto/math_service.proto +++ b/cartographer_grpc/framework/proto/math_service.proto @@ -36,4 +36,5 @@ message GetSquareResponse { service Math { rpc GetSum(stream GetSumRequest) returns (GetSumResponse); rpc GetSquare(GetSquareRequest) returns (GetSquareResponse); + rpc GetRunningSum(stream GetSumRequest) returns (stream GetSumResponse); } diff --git a/cartographer_grpc/framework/rpc.cc b/cartographer_grpc/framework/rpc.cc index b9d400b..9f6a67c 100644 --- a/cartographer_grpc/framework/rpc.cc +++ b/cartographer_grpc/framework/rpc.cc @@ -22,6 +22,23 @@ namespace cartographer_grpc { namespace framework { +namespace { + +// Finishes the gRPC for non-streaming response RPCs, i.e. NORMAL_RPC and +// CLIENT_STREAMING. If no 'msg' is passed, we signal an error to the client as +// the server is not honoring the gRPC call signature. +template +void SendUnaryFinish(ReaderWriter* reader_writer, ::grpc::Status status, + const google::protobuf::Message* msg, + Rpc::RpcEvent* rpc_event) { + if (msg) { + reader_writer->Finish(*msg, status, rpc_event); + } else { + reader_writer->FinishWithError(status, rpc_event); + } +} + +} // namespace Rpc::Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, @@ -60,10 +77,20 @@ void Rpc::OnRequest() { handler_->OnRequestInternal(request_.get()); } void Rpc::OnReadsDone() { handler_->OnReadsDone(); } void Rpc::RequestNextMethodInvocation() { + // Ask gRPC to notify us when the connection terminates. done_event_.pending = true; - new_connection_event_.pending = true; server_context_.AsyncNotifyWhenDone(&done_event_); + + // Make sure after terminating the connection, gRPC notifies us with this + // event. + new_connection_event_.pending = true; switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::BIDI_STREAMING: + service_->RequestAsyncBidiStreaming( + method_index_, &server_context_, streaming_interface(), + server_completion_queue_, server_completion_queue_, + &new_connection_event_); + break; case ::grpc::internal::RpcMethod::CLIENT_STREAMING: service_->RequestAsyncClientStreaming( method_index_, &server_context_, streaming_interface(), @@ -84,6 +111,7 @@ void Rpc::RequestNextMethodInvocation() { void Rpc::RequestStreamingReadIfNeeded() { // For request-streaming RPCs ask the client to start sending requests. switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::BIDI_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING: read_event_.pending = true; async_reader_interface()->Read(request_.get(), &read_event_); @@ -101,25 +129,89 @@ void Rpc::RequestStreamingReadIfNeeded() { } void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) { - response_ = std::move(message); switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::BIDI_STREAMING: + // For BIDI_STREAMING enqueue the message into the send queue and + // start write operations if none are currently in flight. + send_queue_.emplace(SendItem{std::move(message), ::grpc::Status::OK}); + PerformWriteIfNeeded(); + break; case ::grpc::internal::RpcMethod::CLIENT_STREAMING: - server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK, - &finish_event_); - finish_event_.pending = true; + SendFinish(std::move(message), ::grpc::Status::OK); break; case ::grpc::internal::RpcMethod::NORMAL_RPC: - server_async_response_writer_->Finish(*response_.get(), - ::grpc::Status::OK, &finish_event_); - finish_event_.pending = true; + SendFinish(std::move(message), ::grpc::Status::OK); break; default: LOG(FATAL) << "RPC type not implemented."; } } +void Rpc::SendFinish(std::unique_ptr<::google::protobuf::Message> message, + ::grpc::Status status) { + finish_event_.pending = true; + switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::BIDI_STREAMING: + CHECK(!message); + server_async_reader_writer_->Finish(status, &finish_event_); + break; + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + response_ = std::move(message); + SendUnaryFinish(server_async_reader_.get(), status, response_.get(), + &finish_event_); + break; + case ::grpc::internal::RpcMethod::NORMAL_RPC: + response_ = std::move(message); + SendUnaryFinish(server_async_response_writer_.get(), status, + response_.get(), &finish_event_); + break; + default: + LOG(FATAL) << "RPC type not implemented."; + } +} + +void Rpc::Finish(::grpc::Status status) { + switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::BIDI_STREAMING: + send_queue_.emplace(SendItem{nullptr /* msg */, status}); + PerformWriteIfNeeded(); + break; + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + SendFinish(nullptr /* message */, status); + break; + case ::grpc::internal::RpcMethod::NORMAL_RPC: + SendFinish(nullptr /* message */, status); + break; + default: + LOG(FATAL) << "RPC type not implemented."; + } +} + +void Rpc::PerformWriteIfNeeded() { + if (send_queue_.empty() || write_event_.pending) { + return; + } + + // Make sure not other send operations are in-flight. + CHECK(!finish_event_.pending); + + SendItem send_item = std::move(send_queue_.front()); + send_queue_.pop(); + response_ = std::move(send_item.msg); + + if (response_) { + write_event_.pending = true; + async_writer_interface()->Write(*response_.get(), &write_event_); + } else { + CHECK(send_queue_.empty()); + SendFinish(nullptr /* message */, send_item.status); + } +} + ::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() { switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::BIDI_STREAMING: + return server_async_reader_writer_.get(); case ::grpc::internal::RpcMethod::CLIENT_STREAMING: return server_async_reader_.get(); case ::grpc::internal::RpcMethod::NORMAL_RPC: @@ -133,10 +225,28 @@ void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) { ::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>* Rpc::async_reader_interface() { switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::BIDI_STREAMING: + return server_async_reader_writer_.get(); case ::grpc::internal::RpcMethod::CLIENT_STREAMING: return server_async_reader_.get(); case ::grpc::internal::RpcMethod::NORMAL_RPC: - LOG(FATAL) << "For NORMAL_RPC no streaming interface exists."; + LOG(FATAL) << "For NORMAL_RPC no streaming reader interface exists."; + default: + LOG(FATAL) << "RPC type not implemented."; + } + LOG(FATAL) << "Never reached."; +} + +::grpc::internal::AsyncWriterInterface<::google::protobuf::Message>* +Rpc::async_writer_interface() { + switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::BIDI_STREAMING: + return server_async_reader_writer_.get(); + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + case ::grpc::internal::RpcMethod::NORMAL_RPC: + LOG(FATAL) << "For NORMAL_RPC and CLIENT_STREAMING no streaming writer " + "interface exists."; + break; default: LOG(FATAL) << "RPC type not implemented."; } @@ -145,16 +255,16 @@ Rpc::async_reader_interface() { Rpc::RpcEvent* Rpc::GetRpcEvent(Event event) { switch (event) { + case Event::DONE: + return &done_event_; + case Event::FINISH: + return &finish_event_; case Event::NEW_CONNECTION: return &new_connection_event_; case Event::READ: return &read_event_; case Event::WRITE: return &write_event_; - case Event::FINISH: - return &finish_event_; - case Event::DONE: - return &done_event_; } LOG(FATAL) << "Never reached."; } @@ -164,6 +274,12 @@ ActiveRpcs::ActiveRpcs() : lock_() {} void Rpc::InitializeReadersAndWriters( ::grpc::internal::RpcMethod::RpcType rpc_type) { switch (rpc_type) { + case ::grpc::internal::RpcMethod::BIDI_STREAMING: + server_async_reader_writer_ = + cartographer::common::make_unique<::grpc::ServerAsyncReaderWriter< + google::protobuf::Message, google::protobuf::Message>>( + &server_context_); + break; case ::grpc::internal::RpcMethod::CLIENT_STREAMING: server_async_reader_ = cartographer::common::make_unique<::grpc::ServerAsyncReader< diff --git a/cartographer_grpc/framework/rpc.h b/cartographer_grpc/framework/rpc.h index 1616de1..5912a27 100644 --- a/cartographer_grpc/framework/rpc.h +++ b/cartographer_grpc/framework/rpc.h @@ -18,6 +18,7 @@ #define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_H #include +#include #include #include "cartographer/common/mutex.h" @@ -55,18 +56,30 @@ class Rpc { void OnReadsDone(); void RequestNextMethodInvocation(); void RequestStreamingReadIfNeeded(); + void PerformWriteIfNeeded(); void Write(std::unique_ptr<::google::protobuf::Message> message); + void Finish(::grpc::Status status); Service* service() { return service_; } RpcEvent* GetRpcEvent(Event event); private: + struct SendItem { + std::unique_ptr msg; + ::grpc::Status status; + }; + Rpc(const Rpc&) = delete; Rpc& operator=(const Rpc&) = delete; void InitializeReadersAndWriters( ::grpc::internal::RpcMethod::RpcType rpc_type); + void SendFinish(std::unique_ptr<::google::protobuf::Message> message, + ::grpc::Status status); ::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>* async_reader_interface(); + ::grpc::internal::AsyncWriterInterface<::google::protobuf::Message>* + async_writer_interface(); + ::grpc::internal::ServerAsyncStreamingInterface* streaming_interface(); int method_index_; @@ -92,6 +105,11 @@ class Rpc { std::unique_ptr<::grpc::ServerAsyncReader> server_async_reader_; + std::unique_ptr<::grpc::ServerAsyncReaderWriter> + server_async_reader_writer_; + + std::queue send_queue_; }; // This class keeps track of all in-flight RPCs for a 'Service'. Make sure that diff --git a/cartographer_grpc/framework/rpc_handler.h b/cartographer_grpc/framework/rpc_handler.h index aab415c..733f66d 100644 --- a/cartographer_grpc/framework/rpc_handler.h +++ b/cartographer_grpc/framework/rpc_handler.h @@ -45,6 +45,7 @@ class RpcHandler : public RpcHandlerInterface { OnRequest(static_cast(*request)); } virtual void OnRequest(const RequestType& request) = 0; + void Finish(::grpc::Status status) { rpc_->Finish(status); } void Send(std::unique_ptr response) { rpc_->Write(std::move(response)); } diff --git a/cartographer_grpc/framework/server_test.cc b/cartographer_grpc/framework/server_test.cc index 49d48c4..f4a9a4a 100644 --- a/cartographer_grpc/framework/server_test.cc +++ b/cartographer_grpc/framework/server_test.cc @@ -51,6 +51,29 @@ class GetSumHandler int sum_ = 0; }; +class GetRunningSumHandler + : public RpcHandler, Stream> { + public: + void OnRequest(const proto::GetSumRequest& request) override { + sum_ += request.input(); + + // Respond twice to demonstrate bidirectional streaming. + auto response = cartographer::common::make_unique(); + response->set_output(sum_); + Send(std::move(response)); + response = cartographer::common::make_unique(); + response->set_output(sum_); + Send(std::move(response)); + } + + void OnReadsDone() override { + Finish(::grpc::Status::OK); + } + + private: + int sum_ = 0; +}; + class GetSquareHandler : public RpcHandler { void OnRequest(const proto::GetSquareRequest& request) override { @@ -75,6 +98,7 @@ class ServerTest : public ::testing::Test { server_builder.SetNumberOfThreads(kNumThreads); server_builder.RegisterHandler("GetSum"); server_builder.RegisterHandler("GetSquare"); + server_builder.RegisterHandler("GetRunningSum"); server_ = server_builder.Build(); client_channel_ = @@ -127,6 +151,28 @@ TEST_F(ServerTest, ProcessUnaryRpcTest) { server_->Shutdown(); } +TEST_F(ServerTest, ProcessBidiStreamingRpcTest) { + server_->Start(); + + auto reader_writer = stub_->GetRunningSum(&client_context_); + for (int i = 0; i < 3; ++i) { + proto::GetSumRequest request; + request.set_input(i); + EXPECT_TRUE(reader_writer->Write(request)); + } + reader_writer->WritesDone(); + proto::GetSumResponse response; + + std::list expected_responses = {0, 0, 1, 1, 3, 3}; + while (reader_writer->Read(&response)) { + EXPECT_EQ(expected_responses.front(), response.output()); + expected_responses.pop_front(); + } + EXPECT_TRUE(expected_responses.empty()); + + server_->Shutdown(); +} + } // namespace } // namespace framework } // namespace cartographer_grpc diff --git a/cartographer_grpc/framework/service.cc b/cartographer_grpc/framework/service.cc index 87405c7..7496e04 100644 --- a/cartographer_grpc/framework/service.cc +++ b/cartographer_grpc/framework/service.cc @@ -105,6 +105,8 @@ void Service::HandleRead(Rpc* rpc, bool ok) { // Reads completed. rpc->OnReadsDone(); + + RemoveIfNotPending(rpc); } void Service::HandleWrite(Rpc* rpc, bool ok) { @@ -112,6 +114,9 @@ void Service::HandleWrite(Rpc* rpc, bool ok) { LOG(ERROR) << "Write failed"; } + // Send the next message or potentially finish the connection. + rpc->PerformWriteIfNeeded(); + RemoveIfNotPending(rpc); }