diff --git a/cartographer_grpc/framework/rpc.cc b/cartographer_grpc/framework/rpc.cc index 4f0d3d3..e52d609 100644 --- a/cartographer_grpc/framework/rpc.cc +++ b/cartographer_grpc/framework/rpc.cc @@ -43,12 +43,14 @@ void SendUnaryFinish(ReaderWriter* reader_writer, ::grpc::Status status, Rpc::Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, ExecutionContext* execution_context, - const RpcHandlerInfo& rpc_handler_info, Service* service) + const RpcHandlerInfo& rpc_handler_info, Service* service, + WeakPtrFactory weak_ptr_factory) : method_index_(method_index), server_completion_queue_(server_completion_queue), execution_context_(execution_context), rpc_handler_info_(rpc_handler_info), service_(service), + weak_ptr_factory_(weak_ptr_factory), handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) { InitializeReadersAndWriters(rpc_handler_info_.rpc_type); @@ -64,7 +66,7 @@ Rpc::Rpc(int method_index, std::unique_ptr Rpc::Clone() { return cartographer::common::make_unique( method_index_, server_completion_queue_, execution_context_, - rpc_handler_info_, service_); + rpc_handler_info_, service_, weak_ptr_factory_); } void Rpc::OnRequest() { handler_->OnRequestInternal(request_.get()); } @@ -74,7 +76,8 @@ void Rpc::OnReadsDone() { handler_->OnReadsDone(); } void Rpc::RequestNextMethodInvocation() { // Ask gRPC to notify us when the connection terminates. SetRpcEventState(Event::DONE, true); - server_context_.AsyncNotifyWhenDone(new RpcEvent{Event::DONE, this}); + server_context_.AsyncNotifyWhenDone( + new RpcEvent{Event::DONE, weak_ptr_factory_(this)}); // Make sure after terminating the connection, gRPC notifies us with this // event. @@ -84,19 +87,20 @@ void Rpc::RequestNextMethodInvocation() { service_->RequestAsyncBidiStreaming( method_index_, &server_context_, streaming_interface(), server_completion_queue_, server_completion_queue_, - new RpcEvent{Event::NEW_CONNECTION, this}); + new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)}); break; case ::grpc::internal::RpcMethod::CLIENT_STREAMING: service_->RequestAsyncClientStreaming( method_index_, &server_context_, streaming_interface(), server_completion_queue_, server_completion_queue_, - new RpcEvent{Event::NEW_CONNECTION, this}); + new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)}); break; case ::grpc::internal::RpcMethod::NORMAL_RPC: service_->RequestAsyncUnary( method_index_, &server_context_, request_.get(), streaming_interface(), server_completion_queue_, - server_completion_queue_, new RpcEvent{Event::NEW_CONNECTION, this}); + server_completion_queue_, + new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)}); break; default: LOG(FATAL) << "RPC type not implemented."; @@ -109,8 +113,8 @@ void Rpc::RequestStreamingReadIfNeeded() { case ::grpc::internal::RpcMethod::BIDI_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING: SetRpcEventState(Event::READ, true); - async_reader_interface()->Read(request_.get(), - new RpcEvent{Event::READ, this}); + async_reader_interface()->Read( + request_.get(), new RpcEvent{Event::READ, weak_ptr_factory_(this)}); break; case ::grpc::internal::RpcMethod::NORMAL_RPC: // For NORMAL_RPC we don't have to do anything here, since gRPC @@ -149,18 +153,19 @@ void Rpc::SendFinish(std::unique_ptr<::google::protobuf::Message> message, switch (rpc_handler_info_.rpc_type) { case ::grpc::internal::RpcMethod::BIDI_STREAMING: CHECK(!message); - server_async_reader_writer_->Finish(status, - new RpcEvent{Event::FINISH, this}); + server_async_reader_writer_->Finish( + status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this)}); break; case ::grpc::internal::RpcMethod::CLIENT_STREAMING: response_ = std::move(message); SendUnaryFinish(server_async_reader_.get(), status, response_.get(), - new RpcEvent{Event::FINISH, this}); + new RpcEvent{Event::FINISH, weak_ptr_factory_(this)}); break; case ::grpc::internal::RpcMethod::NORMAL_RPC: response_ = std::move(message); SendUnaryFinish(server_async_response_writer_.get(), status, - response_.get(), new RpcEvent{Event::FINISH, this}); + response_.get(), + new RpcEvent{Event::FINISH, weak_ptr_factory_(this)}); break; default: LOG(FATAL) << "RPC type not implemented."; @@ -198,8 +203,8 @@ void Rpc::PerformWriteIfNeeded() { if (response_) { SetRpcEventState(Event::WRITE, true); - async_writer_interface()->Write(*response_.get(), - new RpcEvent{Event::WRITE, this}); + async_writer_interface()->Write( + *response_.get(), new RpcEvent{Event::WRITE, weak_ptr_factory_(this)}); } else { CHECK(send_queue_.empty()); SendFinish(nullptr /* message */, send_item.status); @@ -314,23 +319,34 @@ ActiveRpcs::~ActiveRpcs() { } } -Rpc* ActiveRpcs::Add(std::unique_ptr rpc) { +std::shared_ptr ActiveRpcs::Add(std::unique_ptr rpc) { cartographer::common::MutexLocker locker(&lock_); - const auto result = rpcs_.emplace(rpc.release()); + std::shared_ptr shared_ptr_rpc = std::move(rpc); + const auto result = rpcs_.emplace(shared_ptr_rpc.get(), shared_ptr_rpc); CHECK(result.second) << "RPC already active."; - return *result.first; + return shared_ptr_rpc; } bool ActiveRpcs::Remove(Rpc* rpc) { cartographer::common::MutexLocker locker(&lock_); auto it = rpcs_.find(rpc); if (it != rpcs_.end()) { - delete rpc; rpcs_.erase(it); return true; } return false; } +Rpc::WeakPtrFactory ActiveRpcs::GetWeakPtrFactory() { + return [this](Rpc* rpc) { return GetWeakPtr(rpc); }; +} + +std::weak_ptr ActiveRpcs::GetWeakPtr(Rpc* rpc) { + cartographer::common::MutexLocker locker(&lock_); + auto it = rpcs_.find(rpc); + CHECK(it != rpcs_.end()); + return it->second; +} + } // namespace framework } // namespace cartographer_grpc diff --git a/cartographer_grpc/framework/rpc.h b/cartographer_grpc/framework/rpc.h index 2afe96a..094015e 100644 --- a/cartographer_grpc/framework/rpc.h +++ b/cartographer_grpc/framework/rpc.h @@ -37,15 +37,17 @@ namespace framework { class Service; class Rpc { public: + using WeakPtrFactory = std::function(Rpc*)>; enum class Event { NEW_CONNECTION = 0, READ, WRITE, FINISH, DONE }; struct RpcEvent { const Event event; - Rpc* rpc; + std::weak_ptr rpc; }; Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, ExecutionContext* execution_context, - const RpcHandlerInfo& rpc_handler_info, Service* service); + const RpcHandlerInfo& rpc_handler_info, Service* service, + WeakPtrFactory weak_ptr_factory); std::unique_ptr Clone(); void OnRequest(); void OnReadsDone(); @@ -85,6 +87,7 @@ class Rpc { ExecutionContext* execution_context_; RpcHandlerInfo rpc_handler_info_; Service* service_; + WeakPtrFactory weak_ptr_factory_; ::grpc::ServerContext server_context_; // These state variables indicate whether the corresponding event is currently @@ -122,12 +125,15 @@ class ActiveRpcs { ActiveRpcs(); ~ActiveRpcs() EXCLUDES(lock_); - Rpc* Add(std::unique_ptr rpc) EXCLUDES(lock_); + std::shared_ptr Add(std::unique_ptr rpc) EXCLUDES(lock_); bool Remove(Rpc* rpc) EXCLUDES(lock_); + Rpc::WeakPtrFactory GetWeakPtrFactory(); private: + std::weak_ptr GetWeakPtr(Rpc* rpc); + cartographer::common::Mutex lock_; - std::unordered_set rpcs_; + std::map> rpcs_; }; } // namespace framework diff --git a/cartographer_grpc/framework/server.cc b/cartographer_grpc/framework/server.cc index 8b978e4..6a3e2fa 100644 --- a/cartographer_grpc/framework/server.cc +++ b/cartographer_grpc/framework/server.cc @@ -66,8 +66,11 @@ void Server::RunCompletionQueue( void* tag; while (completion_queue->Next(&tag, &ok)) { auto* rpc_event = static_cast(tag); - rpc_event->rpc->service()->HandleEvent(rpc_event->event, rpc_event->rpc, - ok); + if (auto rpc = rpc_event->rpc.lock()) { + rpc->service()->HandleEvent(rpc_event->event, rpc.get(), ok); + } else { + LOG(WARNING) << "Ignoring stale event."; + } delete rpc_event; } } diff --git a/cartographer_grpc/framework/service.cc b/cartographer_grpc/framework/service.cc index 64ef53e..c501054 100644 --- a/cartographer_grpc/framework/service.cc +++ b/cartographer_grpc/framework/service.cc @@ -40,9 +40,10 @@ void Service::StartServing( int i = 0; for (const auto& rpc_handler_info : rpc_handler_infos_) { for (auto& completion_queue_thread : completion_queue_threads) { - Rpc* rpc = active_rpcs_.Add(cartographer::common::make_unique( - i, completion_queue_thread.completion_queue(), execution_context, - rpc_handler_info.second, this)); + std::shared_ptr rpc = + active_rpcs_.Add(cartographer::common::make_unique( + i, completion_queue_thread.completion_queue(), execution_context, + rpc_handler_info.second, this, active_rpcs_.GetWeakPtrFactory())); rpc->RequestNextMethodInvocation(); } ++i;