In RpcEvent use std::weak_ptr<Rpc> rather than Rpc* (#757)

[RFC=0002](https://github.com/googlecartographer/rfcs/blob/master/text/0002-cloud-based-mapping-1.md)
Branch: master
Author: Christoph Schütte, 2017-12-14 16:30:01 +01:00, committed by Wally B. Feed
Parent: 29e4395a5a
Commit: e023ec5ecc
4 changed files with 53 additions and 27 deletions

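The change below replaces the raw Rpc* stored in each RpcEvent with a std::weak_ptr<Rpc>. A completion-queue tag can still be delivered after the Rpc it refers to has been removed from the server, and a raw pointer in the tag would then dangle; a weak pointer lets the consumer detect that case instead. A minimal, self-contained sketch of the idea (the Session and EventTag names are illustrative, not the Cartographer classes):

// Minimal sketch, not the Cartographer sources: an event tag that stores a
// std::weak_ptr instead of a raw pointer, so whoever processes the tag can
// notice that the target object no longer exists. All names are illustrative.
#include <iostream>
#include <memory>

struct Session {
  void Handle() { std::cout << "handling event\n"; }
};

struct EventTag {
  std::weak_ptr<Session> session;  // was: Session* session;
};

void Consume(const EventTag& tag) {
  // lock() returns a shared_ptr if the Session is still alive, empty otherwise.
  if (auto session = tag.session.lock()) {
    session->Handle();
  } else {
    std::cout << "ignoring stale event\n";
  }
}

int main() {
  auto session = std::make_shared<Session>();
  EventTag tag{session};
  Consume(tag);     // Session still alive: the event is handled.
  session.reset();  // Destroy the Session; the tag is now stale.
  Consume(tag);     // lock() fails, the event is ignored.
}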
View File

@@ -43,12 +43,14 @@ void SendUnaryFinish(ReaderWriter* reader_writer, ::grpc::Status status,
 Rpc::Rpc(int method_index,
          ::grpc::ServerCompletionQueue* server_completion_queue,
          ExecutionContext* execution_context,
-         const RpcHandlerInfo& rpc_handler_info, Service* service)
+         const RpcHandlerInfo& rpc_handler_info, Service* service,
+         WeakPtrFactory weak_ptr_factory)
     : method_index_(method_index),
       server_completion_queue_(server_completion_queue),
       execution_context_(execution_context),
       rpc_handler_info_(rpc_handler_info),
       service_(service),
+      weak_ptr_factory_(weak_ptr_factory),
       handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) {
   InitializeReadersAndWriters(rpc_handler_info_.rpc_type);
@@ -64,7 +66,7 @@ Rpc::Rpc(int method_index,
 std::unique_ptr<Rpc> Rpc::Clone() {
   return cartographer::common::make_unique<Rpc>(
       method_index_, server_completion_queue_, execution_context_,
-      rpc_handler_info_, service_);
+      rpc_handler_info_, service_, weak_ptr_factory_);
 }
 
 void Rpc::OnRequest() { handler_->OnRequestInternal(request_.get()); }
@@ -74,7 +76,8 @@ void Rpc::OnReadsDone() { handler_->OnReadsDone(); }
 void Rpc::RequestNextMethodInvocation() {
   // Ask gRPC to notify us when the connection terminates.
   SetRpcEventState(Event::DONE, true);
-  server_context_.AsyncNotifyWhenDone(new RpcEvent{Event::DONE, this});
+  server_context_.AsyncNotifyWhenDone(
+      new RpcEvent{Event::DONE, weak_ptr_factory_(this)});
 
   // Make sure after terminating the connection, gRPC notifies us with this
   // event.
@@ -84,19 +87,20 @@ void Rpc::RequestNextMethodInvocation() {
       service_->RequestAsyncBidiStreaming(
           method_index_, &server_context_, streaming_interface(),
           server_completion_queue_, server_completion_queue_,
-          new RpcEvent{Event::NEW_CONNECTION, this});
+          new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)});
       break;
     case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
       service_->RequestAsyncClientStreaming(
          method_index_, &server_context_, streaming_interface(),
          server_completion_queue_, server_completion_queue_,
-         new RpcEvent{Event::NEW_CONNECTION, this});
+         new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)});
      break;
    case ::grpc::internal::RpcMethod::NORMAL_RPC:
      service_->RequestAsyncUnary(
          method_index_, &server_context_, request_.get(),
          streaming_interface(), server_completion_queue_,
-          server_completion_queue_, new RpcEvent{Event::NEW_CONNECTION, this});
+          server_completion_queue_,
+          new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)});
      break;
    default:
      LOG(FATAL) << "RPC type not implemented.";
@@ -109,8 +113,8 @@ void Rpc::RequestStreamingReadIfNeeded() {
     case ::grpc::internal::RpcMethod::BIDI_STREAMING:
     case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
       SetRpcEventState(Event::READ, true);
-      async_reader_interface()->Read(request_.get(),
-                                     new RpcEvent{Event::READ, this});
+      async_reader_interface()->Read(
+          request_.get(), new RpcEvent{Event::READ, weak_ptr_factory_(this)});
       break;
     case ::grpc::internal::RpcMethod::NORMAL_RPC:
       // For NORMAL_RPC we don't have to do anything here, since gRPC
@@ -149,18 +153,19 @@ void Rpc::SendFinish(std::unique_ptr<::google::protobuf::Message> message,
   switch (rpc_handler_info_.rpc_type) {
     case ::grpc::internal::RpcMethod::BIDI_STREAMING:
       CHECK(!message);
-      server_async_reader_writer_->Finish(status,
-                                          new RpcEvent{Event::FINISH, this});
+      server_async_reader_writer_->Finish(
+          status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this)});
       break;
     case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
       response_ = std::move(message);
       SendUnaryFinish(server_async_reader_.get(), status, response_.get(),
-                      new RpcEvent{Event::FINISH, this});
+                      new RpcEvent{Event::FINISH, weak_ptr_factory_(this)});
       break;
     case ::grpc::internal::RpcMethod::NORMAL_RPC:
       response_ = std::move(message);
       SendUnaryFinish(server_async_response_writer_.get(), status,
-                      response_.get(), new RpcEvent{Event::FINISH, this});
+                      response_.get(),
+                      new RpcEvent{Event::FINISH, weak_ptr_factory_(this)});
       break;
     default:
       LOG(FATAL) << "RPC type not implemented.";
@@ -198,8 +203,8 @@ void Rpc::PerformWriteIfNeeded() {
   if (response_) {
     SetRpcEventState(Event::WRITE, true);
-    async_writer_interface()->Write(*response_.get(),
-                                    new RpcEvent{Event::WRITE, this});
+    async_writer_interface()->Write(
+        *response_.get(), new RpcEvent{Event::WRITE, weak_ptr_factory_(this)});
   } else {
     CHECK(send_queue_.empty());
     SendFinish(nullptr /* message */, send_item.status);
@@ -314,23 +319,34 @@ ActiveRpcs::~ActiveRpcs() {
   }
 }
 
-Rpc* ActiveRpcs::Add(std::unique_ptr<Rpc> rpc) {
+std::shared_ptr<Rpc> ActiveRpcs::Add(std::unique_ptr<Rpc> rpc) {
   cartographer::common::MutexLocker locker(&lock_);
-  const auto result = rpcs_.emplace(rpc.release());
+  std::shared_ptr<Rpc> shared_ptr_rpc = std::move(rpc);
+  const auto result = rpcs_.emplace(shared_ptr_rpc.get(), shared_ptr_rpc);
   CHECK(result.second) << "RPC already active.";
-  return *result.first;
+  return shared_ptr_rpc;
 }
 
 bool ActiveRpcs::Remove(Rpc* rpc) {
   cartographer::common::MutexLocker locker(&lock_);
   auto it = rpcs_.find(rpc);
   if (it != rpcs_.end()) {
-    delete rpc;
     rpcs_.erase(it);
     return true;
   }
   return false;
 }
 
+Rpc::WeakPtrFactory ActiveRpcs::GetWeakPtrFactory() {
+  return [this](Rpc* rpc) { return GetWeakPtr(rpc); };
+}
+
+std::weak_ptr<Rpc> ActiveRpcs::GetWeakPtr(Rpc* rpc) {
+  cartographer::common::MutexLocker locker(&lock_);
+  auto it = rpcs_.find(rpc);
+  CHECK(it != rpcs_.end());
+  return it->second;
+}
+
 }  // namespace framework
 }  // namespace cartographer_grpc

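ActiveRpcs now owns each Rpc through a std::shared_ptr kept in a map keyed by the raw pointer, which is what lets GetWeakPtr() turn a plain Rpc* back into a std::weak_ptr<Rpc>. Remove() also drops the explicit delete: erasing the map entry releases the last owning reference, and any RpcEvents still in flight simply hold expired weak pointers. A generic sketch of the same registry shape, with hypothetical names (Registry, Add, Remove, GetWeakPtrFactory mirror the calls above but are not the real classes) and std::mutex standing in for cartographer::common::Mutex:

// Sketch of the registry idea used by ActiveRpcs above, hypothetical names:
// objects are owned by shared_ptrs stored in a map keyed by their raw
// pointer, so a raw pointer can later be converted back into a weak_ptr.
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <mutex>

template <typename T>
class Registry {
 public:
  using WeakPtrFactory = std::function<std::weak_ptr<T>(T*)>;

  std::shared_ptr<T> Add(std::unique_ptr<T> object) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::shared_ptr<T> shared = std::move(object);
    const bool inserted = objects_.emplace(shared.get(), shared).second;
    assert(inserted && "object already registered");
    (void)inserted;
    return shared;
  }

  bool Remove(T* object) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Erasing drops the owning shared_ptr; no explicit delete is needed, and
    // weak_ptrs handed out earlier expire instead of dangling.
    return objects_.erase(object) > 0;
  }

  // Counterpart of ActiveRpcs::GetWeakPtrFactory(): a closure that maps a raw
  // pointer back to a weak_ptr by looking it up in the registry.
  WeakPtrFactory GetWeakPtrFactory() {
    return [this](T* object) { return GetWeakPtr(object); };
  }

 private:
  std::weak_ptr<T> GetWeakPtr(T* object) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = objects_.find(object);
    assert(it != objects_.end());
    return it->second;
  }

  std::mutex mutex_;
  std::map<T*, std::shared_ptr<T>> objects_;
};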
View File

@@ -37,15 +37,17 @@ namespace framework {
 class Service;
 class Rpc {
  public:
+  using WeakPtrFactory = std::function<std::weak_ptr<Rpc>(Rpc*)>;
   enum class Event { NEW_CONNECTION = 0, READ, WRITE, FINISH, DONE };
   struct RpcEvent {
     const Event event;
-    Rpc* rpc;
+    std::weak_ptr<Rpc> rpc;
   };
 
   Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue,
       ExecutionContext* execution_context,
-      const RpcHandlerInfo& rpc_handler_info, Service* service);
+      const RpcHandlerInfo& rpc_handler_info, Service* service,
+      WeakPtrFactory weak_ptr_factory);
   std::unique_ptr<Rpc> Clone();
   void OnRequest();
   void OnReadsDone();
@@ -85,6 +87,7 @@ class Rpc {
   ExecutionContext* execution_context_;
   RpcHandlerInfo rpc_handler_info_;
   Service* service_;
+  WeakPtrFactory weak_ptr_factory_;
   ::grpc::ServerContext server_context_;
 
   // These state variables indicate whether the corresponding event is currently
@@ -122,12 +125,15 @@ class ActiveRpcs {
   ActiveRpcs();
   ~ActiveRpcs() EXCLUDES(lock_);
 
-  Rpc* Add(std::unique_ptr<Rpc> rpc) EXCLUDES(lock_);
+  std::shared_ptr<Rpc> Add(std::unique_ptr<Rpc> rpc) EXCLUDES(lock_);
   bool Remove(Rpc* rpc) EXCLUDES(lock_);
+  Rpc::WeakPtrFactory GetWeakPtrFactory();
 
  private:
+  std::weak_ptr<Rpc> GetWeakPtr(Rpc* rpc);
+
   cartographer::common::Mutex lock_;
-  std::unordered_set<Rpc*> rpcs_;
+  std::map<Rpc*, std::shared_ptr<Rpc>> rpcs_;
 };
 
 }  // namespace framework

View File

@@ -66,8 +66,11 @@ void Server::RunCompletionQueue(
   void* tag;
   while (completion_queue->Next(&tag, &ok)) {
     auto* rpc_event = static_cast<Rpc::RpcEvent*>(tag);
-    rpc_event->rpc->service()->HandleEvent(rpc_event->event, rpc_event->rpc,
-                                           ok);
+    if (auto rpc = rpc_event->rpc.lock()) {
+      rpc->service()->HandleEvent(rpc_event->event, rpc.get(), ok);
+    } else {
+      LOG(WARNING) << "Ignoring stale event.";
+    }
     delete rpc_event;
   }
 }

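In the loop above, lock() does two jobs: it filters out stale tags, and the shared_ptr it returns keeps the Rpc alive for the whole HandleEvent() call even if it is removed from ActiveRpcs in the meantime. The heap-allocated RpcEvent itself is still owned by the loop and is deleted whether or not the Rpc was alive. A small illustration of the lifetime-pinning property, using a hypothetical Job type:

// Sketch with a hypothetical Job type: the shared_ptr obtained from lock()
// pins the object for the enclosing scope, even if the owning reference is
// dropped while the event is being handled.
#include <iostream>
#include <memory>

struct Job {
  ~Job() { std::cout << "Job destroyed\n"; }
  void Handle() { std::cout << "handling\n"; }
};

int main() {
  auto owner = std::make_shared<Job>();  // stands in for the registry's reference
  std::weak_ptr<Job> weak = owner;

  if (auto job = weak.lock()) {  // pin the Job while handling the event
    owner.reset();               // the owner lets go mid-handling...
    job->Handle();               // ...but 'job' keeps the object alive
  }                              // "Job destroyed" is printed only here
  std::cout << "done\n";
}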
View File

@@ -40,9 +40,10 @@ void Service::StartServing(
   int i = 0;
   for (const auto& rpc_handler_info : rpc_handler_infos_) {
     for (auto& completion_queue_thread : completion_queue_threads) {
-      Rpc* rpc = active_rpcs_.Add(cartographer::common::make_unique<Rpc>(
-          i, completion_queue_thread.completion_queue(), execution_context,
-          rpc_handler_info.second, this));
+      std::shared_ptr<Rpc> rpc =
+          active_rpcs_.Add(cartographer::common::make_unique<Rpc>(
+              i, completion_queue_thread.completion_queue(), execution_context,
+              rpc_handler_info.second, this, active_rpcs_.GetWeakPtrFactory()));
       rpc->RequestNextMethodInvocation();
     }
     ++i;
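StartServing() ties the pieces together: ActiveRpcs::Add() takes ownership and returns the shared_ptr, and the factory obtained from GetWeakPtrFactory() is passed into the Rpc constructor so the Rpc can later mint weak pointers to itself for its event tags. A compact self-contained sketch of that wiring, with hypothetical Task and TaskRegistry types standing in for Rpc and ActiveRpcs (and no locking, for brevity):

// Hypothetical Task/TaskRegistry types, not the real classes: the registry
// owns each Task through a shared_ptr and gives the Task a factory that turns
// 'this' back into a weak_ptr whenever an asynchronous operation is tagged.
#include <functional>
#include <iostream>
#include <map>
#include <memory>

class Task;
using WeakFactory = std::function<std::weak_ptr<Task>(Task*)>;

class Task {
 public:
  explicit Task(WeakFactory factory) : factory_(std::move(factory)) {}

  // Stands in for RequestNextMethodInvocation(): tag pending work with a weak
  // handle to this Task instead of a raw pointer.
  std::weak_ptr<Task> Start() { return factory_(this); }

 private:
  WeakFactory factory_;
};

class TaskRegistry {
 public:
  std::shared_ptr<Task> Add(std::unique_ptr<Task> task) {
    std::shared_ptr<Task> shared = std::move(task);
    tasks_.emplace(shared.get(), shared);
    return shared;
  }

  WeakFactory GetWeakFactory() {
    return [this](Task* task) { return tasks_.at(task); };
  }

 private:
  std::map<Task*, std::shared_ptr<Task>> tasks_;
};

int main() {
  TaskRegistry registry;
  // Same shape as StartServing(): the registry creates the factory and passes
  // it to the constructor of the object it will later describe.
  std::shared_ptr<Task> task =
      registry.Add(std::make_unique<Task>(registry.GetWeakFactory()));
  std::weak_ptr<Task> handle = task->Start();
  std::cout << std::boolalpha << "task alive: " << !handle.expired() << "\n";
}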