diff --git a/cartographer_grpc/framework/rpc.cc b/cartographer_grpc/framework/rpc.cc index 5dd5690..a945fd7 100644 --- a/cartographer_grpc/framework/rpc.cc +++ b/cartographer_grpc/framework/rpc.cc @@ -78,7 +78,7 @@ void Rpc::RequestNextMethodInvocation() { // Ask gRPC to notify us when the connection terminates. SetRpcEventState(Event::DONE, true); server_context_.AsyncNotifyWhenDone( - new RpcEvent{Event::DONE, weak_ptr_factory_(this)}); + new RpcEvent{Event::DONE, weak_ptr_factory_(this), true}); // Make sure after terminating the connection, gRPC notifies us with this // event. @@ -88,20 +88,20 @@ void Rpc::RequestNextMethodInvocation() { service_->RequestAsyncBidiStreaming( method_index_, &server_context_, streaming_interface(), server_completion_queue_, server_completion_queue_, - new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)}); + new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true}); break; case ::grpc::internal::RpcMethod::CLIENT_STREAMING: service_->RequestAsyncClientStreaming( method_index_, &server_context_, streaming_interface(), server_completion_queue_, server_completion_queue_, - new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)}); + new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true}); break; case ::grpc::internal::RpcMethod::NORMAL_RPC: service_->RequestAsyncUnary( method_index_, &server_context_, request_.get(), streaming_interface(), server_completion_queue_, server_completion_queue_, - new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)}); + new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true}); break; default: LOG(FATAL) << "RPC type not implemented."; @@ -115,7 +115,8 @@ void Rpc::RequestStreamingReadIfNeeded() { case ::grpc::internal::RpcMethod::CLIENT_STREAMING: SetRpcEventState(Event::READ, true); async_reader_interface()->Read( - request_.get(), new RpcEvent{Event::READ, weak_ptr_factory_(this)}); + request_.get(), + new RpcEvent{Event::READ, weak_ptr_factory_(this), true}); break; case ::grpc::internal::RpcMethod::NORMAL_RPC: // For NORMAL_RPC we don't have to do anything here, since gRPC @@ -130,86 +131,37 @@ void Rpc::RequestStreamingReadIfNeeded() { } void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) { - switch (rpc_handler_info_.rpc_type) { - case ::grpc::internal::RpcMethod::BIDI_STREAMING: - // For BIDI_STREAMING enqueue the message into the send queue and - // start write operations if none are currently in flight. - send_queue_.emplace(SendItem{std::move(message), ::grpc::Status::OK}); - PerformWriteIfNeeded(); - break; - case ::grpc::internal::RpcMethod::CLIENT_STREAMING: - SendFinish(std::move(message), ::grpc::Status::OK); - break; - case ::grpc::internal::RpcMethod::NORMAL_RPC: - SendFinish(std::move(message), ::grpc::Status::OK); - break; - default: - LOG(FATAL) << "RPC type not implemented."; - } -} - -void Rpc::SendFinish(std::unique_ptr<::google::protobuf::Message> message, - ::grpc::Status status) { - SetRpcEventState(Event::FINISH, true); - switch (rpc_handler_info_.rpc_type) { - case ::grpc::internal::RpcMethod::BIDI_STREAMING: - CHECK(!message); - server_async_reader_writer_->Finish( - status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this)}); - break; - case ::grpc::internal::RpcMethod::CLIENT_STREAMING: - response_ = std::move(message); - SendUnaryFinish(server_async_reader_.get(), status, response_.get(), - new RpcEvent{Event::FINISH, weak_ptr_factory_(this)}); - break; - case ::grpc::internal::RpcMethod::NORMAL_RPC: - response_ = std::move(message); - SendUnaryFinish(server_async_response_writer_.get(), status, - response_.get(), - new RpcEvent{Event::FINISH, weak_ptr_factory_(this)}); - break; - default: - LOG(FATAL) << "RPC type not implemented."; - } + EnqueueMessage(SendItem{std::move(message), ::grpc::Status::OK}); + event_queue_->Push( + new RpcEvent{Event::WRITE_NEEDED, weak_ptr_factory_(this), true}); } void Rpc::Finish(::grpc::Status status) { - switch (rpc_handler_info_.rpc_type) { - case ::grpc::internal::RpcMethod::BIDI_STREAMING: - send_queue_.emplace(SendItem{nullptr /* msg */, status}); - PerformWriteIfNeeded(); - break; - case ::grpc::internal::RpcMethod::CLIENT_STREAMING: - SendFinish(nullptr /* message */, status); - break; - case ::grpc::internal::RpcMethod::NORMAL_RPC: - SendFinish(nullptr /* message */, status); - break; - default: - LOG(FATAL) << "RPC type not implemented."; - } + EnqueueMessage(SendItem{nullptr /* message */, status}); + event_queue_->Push( + new RpcEvent{Event::WRITE_NEEDED, weak_ptr_factory_(this), true}); } -void Rpc::PerformWriteIfNeeded() { - if (send_queue_.empty() || IsRpcEventPending(Event::WRITE)) { +void Rpc::HandleSendQueue() { + SendItem send_item; + { + cartographer::common::MutexLocker locker(&send_queue_lock_); + if (send_queue_.empty() || IsRpcEventPending(Event::WRITE) || + IsRpcEventPending(Event::FINISH)) { + return; + } + + send_item = std::move(send_queue_.front()); + send_queue_.pop(); + } + if (!send_item.msg || + rpc_handler_info_.rpc_type == ::grpc::internal::RpcMethod::NORMAL_RPC || + rpc_handler_info_.rpc_type == + ::grpc::internal::RpcMethod::CLIENT_STREAMING) { + PerformFinish(std::move(send_item.msg), send_item.status); return; } - - // Make sure not other send operations are in-flight. - CHECK(!IsRpcEventPending(Event::FINISH)); - - SendItem send_item = std::move(send_queue_.front()); - send_queue_.pop(); - response_ = std::move(send_item.msg); - - if (response_) { - SetRpcEventState(Event::WRITE, true); - async_writer_interface()->Write( - *response_.get(), new RpcEvent{Event::WRITE, weak_ptr_factory_(this)}); - } else { - CHECK(send_queue_.empty()); - SendFinish(nullptr /* message */, send_item.status); - } + PerformWrite(std::move(send_item.msg), send_item.status); } ::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() { @@ -259,20 +211,65 @@ Rpc::async_writer_interface() { bool* Rpc::GetRpcEventState(Event event) { switch (event) { - case Event::DONE: - return &done_event_pending_; - case Event::FINISH: - return &finish_event_pending_; case Event::NEW_CONNECTION: return &new_connection_event_pending_; case Event::READ: return &read_event_pending_; + case Event::WRITE_NEEDED: + return &write_needed_event_pending_; case Event::WRITE: return &write_event_pending_; + case Event::FINISH: + return &finish_event_pending_; + case Event::DONE: + return &done_event_pending_; } LOG(FATAL) << "Never reached."; } +void Rpc::EnqueueMessage(SendItem&& send_item) { + cartographer::common::MutexLocker locker(&send_queue_lock_); + send_queue_.emplace(std::move(send_item)); +} + +void Rpc::PerformFinish(std::unique_ptr<::google::protobuf::Message> message, + ::grpc::Status status) { + SetRpcEventState(Event::FINISH, true); + switch (rpc_handler_info_.rpc_type) { + case ::grpc::internal::RpcMethod::BIDI_STREAMING: + CHECK(!message); + server_async_reader_writer_->Finish( + status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true}); + break; + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + response_ = std::move(message); + SendUnaryFinish( + server_async_reader_.get(), status, response_.get(), + new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true}); + break; + case ::grpc::internal::RpcMethod::NORMAL_RPC: + response_ = std::move(message); + SendUnaryFinish( + server_async_response_writer_.get(), status, response_.get(), + new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true}); + break; + default: + LOG(FATAL) << "RPC type not implemented."; + } +} + +void Rpc::PerformWrite(std::unique_ptr<::google::protobuf::Message> message, + ::grpc::Status status) { + CHECK(message) << "PerformWrite must be called with a non-null message"; + CHECK_NE(rpc_handler_info_.rpc_type, ::grpc::internal::RpcMethod::NORMAL_RPC); + CHECK_NE(rpc_handler_info_.rpc_type, + ::grpc::internal::RpcMethod::CLIENT_STREAMING); + SetRpcEventState(Event::WRITE, true); + response_ = std::move(message); + async_writer_interface()->Write( + *response_, new RpcEvent{Event::WRITE, weak_ptr_factory_(this), true}); +} + void Rpc::SetRpcEventState(Event event, bool pending) { *GetRpcEventState(event) = pending; } diff --git a/cartographer_grpc/framework/rpc.h b/cartographer_grpc/framework/rpc.h index 0c2ddb4..5b83909 100644 --- a/cartographer_grpc/framework/rpc.h +++ b/cartographer_grpc/framework/rpc.h @@ -36,12 +36,20 @@ namespace cartographer_grpc { namespace framework { class Service; +// TODO(cschuet): Add a unittest that tests the logic of this class. class Rpc { public: struct RpcEvent; using EventQueue = cartographer::common::BlockingQueue; using WeakPtrFactory = std::function(Rpc*)>; - enum class Event { NEW_CONNECTION = 0, READ, WRITE, FINISH, DONE }; + enum class Event { + NEW_CONNECTION = 0, + READ, + WRITE_NEEDED, + WRITE, + FINISH, + DONE + }; struct RpcEvent { const Event event; std::weak_ptr rpc; @@ -56,7 +64,7 @@ class Rpc { void OnReadsDone(); void RequestNextMethodInvocation(); void RequestStreamingReadIfNeeded(); - void PerformWriteIfNeeded(); + void HandleSendQueue(); void Write(std::unique_ptr<::google::protobuf::Message> message); void Finish(::grpc::Status status); Service* service() { return service_; } @@ -76,9 +84,12 @@ class Rpc { Rpc& operator=(const Rpc&) = delete; void InitializeReadersAndWriters( ::grpc::internal::RpcMethod::RpcType rpc_type); - void SendFinish(std::unique_ptr<::google::protobuf::Message> message, - ::grpc::Status status); bool* GetRpcEventState(Event event); + void EnqueueMessage(SendItem&& send_item); + void PerformFinish(std::unique_ptr<::google::protobuf::Message> message, + ::grpc::Status status); + void PerformWrite(std::unique_ptr<::google::protobuf::Message> message, + ::grpc::Status status); ::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>* async_reader_interface(); @@ -102,6 +113,7 @@ class Rpc { // indicates that the read has completed and currently no read is in-flight. bool new_connection_event_pending_ = false; bool read_event_pending_ = false; + bool write_needed_event_pending_ = false; bool write_event_pending_ = false; bool finish_event_pending_ = false; bool done_event_pending_ = false; @@ -120,6 +132,7 @@ class Rpc { google::protobuf::Message>> server_async_reader_writer_; + cartographer::common::Mutex send_queue_lock_; std::queue send_queue_; }; diff --git a/cartographer_grpc/framework/service.cc b/cartographer_grpc/framework/service.cc index f40ecb9..d7027e8 100644 --- a/cartographer_grpc/framework/service.cc +++ b/cartographer_grpc/framework/service.cc @@ -66,6 +66,7 @@ void Service::HandleEvent(Rpc::Event event, Rpc* rpc, bool ok) { case Rpc::Event::READ: HandleRead(rpc, ok); break; + case Rpc::Event::WRITE_NEEDED: case Rpc::Event::WRITE: HandleWrite(rpc, ok); break; @@ -123,7 +124,7 @@ void Service::HandleWrite(Rpc* rpc, bool ok) { } // Send the next message or potentially finish the connection. - rpc->PerformWriteIfNeeded(); + rpc->HandleSendQueue(); RemoveIfNotPending(rpc); }