diff --git a/cartographer_grpc/framework/event_queue_thread.h b/cartographer_grpc/framework/event_queue_thread.h index e7d2189..0f497b9 100644 --- a/cartographer_grpc/framework/event_queue_thread.h +++ b/cartographer_grpc/framework/event_queue_thread.h @@ -26,8 +26,6 @@ namespace cartographer_grpc { namespace framework { -using EventQueue = cartographer::common::BlockingQueue; - class EventQueueThread { public: using EventQueueRunner = std::function; diff --git a/cartographer_grpc/framework/rpc.cc b/cartographer_grpc/framework/rpc.cc index 08f111e..e96640e 100644 --- a/cartographer_grpc/framework/rpc.cc +++ b/cartographer_grpc/framework/rpc.cc @@ -30,7 +30,7 @@ namespace { template void SendUnaryFinish(ReaderWriter* reader_writer, ::grpc::Status status, const google::protobuf::Message* msg, - Rpc::RpcEvent* rpc_event) { + Rpc::EventBase* rpc_event) { if (msg) { reader_writer->Finish(*msg, status, rpc_event); } else { @@ -40,6 +40,19 @@ void SendUnaryFinish(ReaderWriter* reader_writer, ::grpc::Status status, } // namespace +void Rpc::CompletionQueueRpcEvent::Handle() { + pending = false; + rpc_ptr->service()->HandleEvent(event, rpc_ptr, ok); +} + +void Rpc::InternalRpcEvent::Handle() { + if (auto rpc_shared = rpc.lock()) { + rpc_shared->service()->HandleEvent(event, rpc_shared.get(), true); + } else { + LOG(WARNING) << "Ignoring stale event."; + } +} + Rpc::Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, EventQueue* event_queue, ExecutionContext* execution_context, @@ -52,6 +65,11 @@ Rpc::Rpc(int method_index, rpc_handler_info_(rpc_handler_info), service_(service), weak_ptr_factory_(weak_ptr_factory), + new_connection_event_(Event::NEW_CONNECTION, this), + read_event_(Event::READ, this), + write_event_(Event::WRITE, this), + finish_event_(Event::FINISH, this), + done_event_(Event::DONE, this), handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) { InitializeReadersAndWriters(rpc_handler_info_.rpc_type); @@ -81,8 +99,7 @@ void Rpc::RequestNextMethodInvocation() { SetRpcEventState(Event::DONE, true); // TODO(gaschler): Asan reports direct leak of this new from both calls // StartServing and HandleNewConnection. - server_context_.AsyncNotifyWhenDone( - new RpcEvent{Event::DONE, weak_ptr_factory_(this), true}); + server_context_.AsyncNotifyWhenDone(GetRpcEvent(Event::DONE)); // Make sure after terminating the connection, gRPC notifies us with this // event. @@ -92,27 +109,25 @@ void Rpc::RequestNextMethodInvocation() { service_->RequestAsyncBidiStreaming( method_index_, &server_context_, streaming_interface(), server_completion_queue_, server_completion_queue_, - new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true}); + GetRpcEvent(Event::NEW_CONNECTION)); break; case ::grpc::internal::RpcMethod::CLIENT_STREAMING: service_->RequestAsyncClientStreaming( method_index_, &server_context_, streaming_interface(), server_completion_queue_, server_completion_queue_, - new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true}); + GetRpcEvent(Event::NEW_CONNECTION)); break; case ::grpc::internal::RpcMethod::NORMAL_RPC: service_->RequestAsyncUnary( method_index_, &server_context_, request_.get(), streaming_interface(), server_completion_queue_, - server_completion_queue_, - new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true}); + server_completion_queue_, GetRpcEvent(Event::NEW_CONNECTION)); break; case ::grpc::internal::RpcMethod::SERVER_STREAMING: service_->RequestAsyncServerStreaming( method_index_, &server_context_, request_.get(), streaming_interface(), server_completion_queue_, - server_completion_queue_, - new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true}); + server_completion_queue_, GetRpcEvent(Event::NEW_CONNECTION)); break; } } @@ -123,9 +138,7 @@ void Rpc::RequestStreamingReadIfNeeded() { case ::grpc::internal::RpcMethod::BIDI_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING: SetRpcEventState(Event::READ, true); - async_reader_interface()->Read( - request_.get(), - new RpcEvent{Event::READ, weak_ptr_factory_(this), true}); + async_reader_interface()->Read(request_.get(), GetRpcEvent(Event::READ)); break; case ::grpc::internal::RpcMethod::NORMAL_RPC: case ::grpc::internal::RpcMethod::SERVER_STREAMING: @@ -140,14 +153,14 @@ void Rpc::RequestStreamingReadIfNeeded() { void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) { EnqueueMessage(SendItem{std::move(message), ::grpc::Status::OK}); - event_queue_->Push( - new RpcEvent{Event::WRITE_NEEDED, weak_ptr_factory_(this), true}); + event_queue_->Push(UniqueEventPtr( + new InternalRpcEvent(Event::WRITE_NEEDED, weak_ptr_factory_(this)))); } void Rpc::Finish(::grpc::Status status) { EnqueueMessage(SendItem{nullptr /* message */, status}); - event_queue_->Push( - new RpcEvent{Event::WRITE_NEEDED, weak_ptr_factory_(this), true}); + event_queue_->Push(UniqueEventPtr( + new InternalRpcEvent(Event::WRITE_NEEDED, weak_ptr_factory_(this)))); } void Rpc::HandleSendQueue() { @@ -218,24 +231,29 @@ Rpc::async_writer_interface() { LOG(FATAL) << "Never reached."; } -bool* Rpc::GetRpcEventState(Event event) { +Rpc::CompletionQueueRpcEvent* Rpc::GetRpcEvent(Event event) { switch (event) { case Event::NEW_CONNECTION: - return &new_connection_event_pending_; + return &new_connection_event_; case Event::READ: - return &read_event_pending_; + return &read_event_; case Event::WRITE_NEEDED: - return &write_needed_event_pending_; + LOG(FATAL) << "Rpc does not store Event::WRITE_NEEDED."; + break; case Event::WRITE: - return &write_event_pending_; + return &write_event_; case Event::FINISH: - return &finish_event_pending_; + return &finish_event_; case Event::DONE: - return &done_event_pending_; + return &done_event_; } LOG(FATAL) << "Never reached."; } +bool* Rpc::GetRpcEventState(Event event) { + return &GetRpcEvent(event)->pending; +} + void Rpc::EnqueueMessage(SendItem&& send_item) { cartographer::common::MutexLocker locker(&send_queue_lock_); send_queue_.emplace(std::move(send_item)); @@ -247,25 +265,21 @@ void Rpc::PerformFinish(std::unique_ptr<::google::protobuf::Message> message, switch (rpc_handler_info_.rpc_type) { case ::grpc::internal::RpcMethod::BIDI_STREAMING: CHECK(!message); - server_async_reader_writer_->Finish( - status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true}); + server_async_reader_writer_->Finish(status, GetRpcEvent(Event::FINISH)); break; case ::grpc::internal::RpcMethod::CLIENT_STREAMING: response_ = std::move(message); - SendUnaryFinish( - server_async_reader_.get(), status, response_.get(), - new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true}); + SendUnaryFinish(server_async_reader_.get(), status, response_.get(), + GetRpcEvent(Event::FINISH)); break; case ::grpc::internal::RpcMethod::NORMAL_RPC: response_ = std::move(message); - SendUnaryFinish( - server_async_response_writer_.get(), status, response_.get(), - new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true}); + SendUnaryFinish(server_async_response_writer_.get(), status, + response_.get(), GetRpcEvent(Event::FINISH)); break; case ::grpc::internal::RpcMethod::SERVER_STREAMING: CHECK(!message); - server_async_writer_->Finish( - status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true}); + server_async_writer_->Finish(status, GetRpcEvent(Event::FINISH)); break; } } @@ -278,11 +292,12 @@ void Rpc::PerformWrite(std::unique_ptr<::google::protobuf::Message> message, ::grpc::internal::RpcMethod::CLIENT_STREAMING); SetRpcEventState(Event::WRITE, true); response_ = std::move(message); - async_writer_interface()->Write( - *response_, new RpcEvent{Event::WRITE, weak_ptr_factory_(this), true}); + async_writer_interface()->Write(*response_, GetRpcEvent(Event::WRITE)); } void Rpc::SetRpcEventState(Event event, bool pending) { + // TODO(gaschler): Since the only usage is setting this true at creation, + // consider removing this method. *GetRpcEventState(event) = pending; } diff --git a/cartographer_grpc/framework/rpc.h b/cartographer_grpc/framework/rpc.h index acee225..3843dda 100644 --- a/cartographer_grpc/framework/rpc.h +++ b/cartographer_grpc/framework/rpc.h @@ -39,8 +39,6 @@ class Service; // TODO(cschuet): Add a unittest that tests the logic of this class. class Rpc { public: - struct RpcEvent; - using EventQueue = cartographer::common::BlockingQueue; using WeakPtrFactory = std::function(Rpc*)>; enum class Event { NEW_CONNECTION = 0, @@ -50,11 +48,61 @@ class Rpc { FINISH, DONE }; - struct RpcEvent { + + struct EventBase { + explicit EventBase(Event event) : event(event) {} + virtual ~EventBase(){}; + virtual void Handle() = 0; + const Event event; - std::weak_ptr rpc; - bool ok; }; + + class EventDeleter { + public: + enum Action { DELETE = 0, DO_NOT_DELETE }; + + // The default action 'DELETE' is used implicitly, for instance for a + // new UniqueEventPtr or a UniqueEventPtr that is created by + // 'return nullptr'. + EventDeleter() : action_(DELETE) {} + explicit EventDeleter(Action action) : action_(action) {} + void operator()(EventBase* e) { + if (e != nullptr && action_ == DELETE) { + delete e; + } + } + + private: + Action action_; + }; + + using UniqueEventPtr = std::unique_ptr; + using EventQueue = cartographer::common::BlockingQueue; + + // Flows through gRPC's CompletionQueue and then our EventQueue. + struct CompletionQueueRpcEvent : public EventBase { + CompletionQueueRpcEvent(Event event, Rpc* rpc) + : EventBase(event), rpc_ptr(rpc), ok(false), pending(false) {} + void PushToEventQueue() { + rpc_ptr->event_queue()->Push( + UniqueEventPtr(this, EventDeleter(EventDeleter::DO_NOT_DELETE))); + } + void Handle() override; + + Rpc* rpc_ptr; + bool ok; + bool pending; + }; + + // Flows only through our EventQueue. + struct InternalRpcEvent : public EventBase { + InternalRpcEvent(Event event, std::weak_ptr rpc) + : EventBase(event), rpc(rpc) {} + void Handle() override; + + std::weak_ptr rpc; + }; + Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, EventQueue* event_queue, ExecutionContext* execution_context, const RpcHandlerInfo& rpc_handler_info, Service* service, @@ -69,7 +117,6 @@ class Rpc { void Write(std::unique_ptr<::google::protobuf::Message> message); void Finish(::grpc::Status status); Service* service() { return service_; } - void SetRpcEventState(Event event, bool pending); bool IsRpcEventPending(Event event); bool IsAnyEventPending(); void SetEventQueue(EventQueue* event_queue) { event_queue_ = event_queue; } @@ -86,7 +133,9 @@ class Rpc { Rpc& operator=(const Rpc&) = delete; void InitializeReadersAndWriters( ::grpc::internal::RpcMethod::RpcType rpc_type); + CompletionQueueRpcEvent* GetRpcEvent(Event event); bool* GetRpcEventState(Event event); + void SetRpcEventState(Event event, bool pending); void EnqueueMessage(SendItem&& send_item); void PerformFinish(std::unique_ptr<::google::protobuf::Message> message, ::grpc::Status status); @@ -109,16 +158,11 @@ class Rpc { WeakPtrFactory weak_ptr_factory_; ::grpc::ServerContext server_context_; - // These state variables indicate whether the corresponding event is currently - // pending completion, e.g. 'read_event_pending_ = true' means that a read has - // been requested but hasn't completed yet. 'read_event_pending_ = false' - // indicates that the read has completed and currently no read is in-flight. - bool new_connection_event_pending_ = false; - bool read_event_pending_ = false; - bool write_needed_event_pending_ = false; - bool write_event_pending_ = false; - bool finish_event_pending_ = false; - bool done_event_pending_ = false; + CompletionQueueRpcEvent new_connection_event_; + CompletionQueueRpcEvent read_event_; + CompletionQueueRpcEvent write_event_; + CompletionQueueRpcEvent finish_event_; + CompletionQueueRpcEvent done_event_; std::unique_ptr request_; std::unique_ptr response_; @@ -140,6 +184,8 @@ class Rpc { std::queue send_queue_; }; +using EventQueue = Rpc::EventQueue; + // This class keeps track of all in-flight RPCs for a 'Service'. Make sure that // all RPCs have been terminated and removed from this object before it goes out // of scope. diff --git a/cartographer_grpc/framework/server.cc b/cartographer_grpc/framework/server.cc index adb1042..2c10462 100644 --- a/cartographer_grpc/framework/server.cc +++ b/cartographer_grpc/framework/server.cc @@ -80,25 +80,12 @@ void Server::RunCompletionQueue( bool ok; void* tag; while (completion_queue->Next(&tag, &ok)) { - auto* rpc_event = static_cast(tag); + auto* rpc_event = static_cast(tag); rpc_event->ok = ok; - if (auto rpc = rpc_event->rpc.lock()) { - rpc->event_queue()->Push(rpc_event); - } else { - LOG(WARNING) << "Ignoring stale event."; - } + rpc_event->PushToEventQueue(); } } -void Server::ProcessRpcEvent(Rpc::RpcEvent* rpc_event) { - if (auto rpc = rpc_event->rpc.lock()) { - rpc->service()->HandleEvent(rpc_event->event, rpc.get(), rpc_event->ok); - } else { - LOG(WARNING) << "Ignoring stale event."; - } - delete rpc_event; -} - EventQueue* Server::SelectNextEventQueueRoundRobin() { cartographer::common::MutexLocker locker(¤t_event_queue_id_lock_); current_event_queue_id_ = @@ -108,16 +95,17 @@ EventQueue* Server::SelectNextEventQueueRoundRobin() { void Server::RunEventQueue(EventQueue* event_queue) { while (!shutting_down_) { - Rpc::RpcEvent* rpc_event = event_queue->PopWithTimeout(kPopEventTimeout); + Rpc::UniqueEventPtr rpc_event = + event_queue->PopWithTimeout(kPopEventTimeout); if (rpc_event) { - ProcessRpcEvent(rpc_event); + rpc_event->Handle(); } } // Finish processing the rest of the items. - while (Rpc::RpcEvent* rpc_event = + while (Rpc::UniqueEventPtr rpc_event = event_queue->PopWithTimeout(kPopEventTimeout)) { - ProcessRpcEvent(rpc_event); + rpc_event->Handle(); } } diff --git a/cartographer_grpc/framework/server.h b/cartographer_grpc/framework/server.h index 92fe51e..26b9ffb 100644 --- a/cartographer_grpc/framework/server.h +++ b/cartographer_grpc/framework/server.h @@ -112,9 +112,8 @@ class Server { const std::string& service_name, const std::map& rpc_handler_infos); void RunCompletionQueue(::grpc::ServerCompletionQueue* completion_queue); - void RunEventQueue(EventQueue* event_queue); - void ProcessRpcEvent(Rpc::RpcEvent* rpc_event); - EventQueue* SelectNextEventQueueRoundRobin(); + void RunEventQueue(Rpc::EventQueue* event_queue); + Rpc::EventQueue* SelectNextEventQueueRoundRobin(); Options options_; diff --git a/cartographer_grpc/framework/service.cc b/cartographer_grpc/framework/service.cc index 69c49ec..875b2d7 100644 --- a/cartographer_grpc/framework/service.cc +++ b/cartographer_grpc/framework/service.cc @@ -58,7 +58,6 @@ void Service::StartServing( void Service::StopServing() { shutting_down_ = true; } void Service::HandleEvent(Rpc::Event event, Rpc* rpc, bool ok) { - rpc->SetRpcEventState(event, false); switch (event) { case Rpc::Event::NEW_CONNECTION: HandleNewConnection(rpc, ok);