Christoph Schütte 2017-12-18 13:36:44 +01:00 committed by Wally B. Feed
parent c6decd5b7b
commit ba7d375a25
3 changed files with 98 additions and 87 deletions

View File

@ -78,7 +78,7 @@ void Rpc::RequestNextMethodInvocation() {
// Ask gRPC to notify us when the connection terminates. // Ask gRPC to notify us when the connection terminates.
SetRpcEventState(Event::DONE, true); SetRpcEventState(Event::DONE, true);
server_context_.AsyncNotifyWhenDone( server_context_.AsyncNotifyWhenDone(
new RpcEvent{Event::DONE, weak_ptr_factory_(this)}); new RpcEvent{Event::DONE, weak_ptr_factory_(this), true});
// Make sure after terminating the connection, gRPC notifies us with this // Make sure after terminating the connection, gRPC notifies us with this
// event. // event.
@ -88,20 +88,20 @@ void Rpc::RequestNextMethodInvocation() {
service_->RequestAsyncBidiStreaming( service_->RequestAsyncBidiStreaming(
method_index_, &server_context_, streaming_interface(), method_index_, &server_context_, streaming_interface(),
server_completion_queue_, server_completion_queue_, server_completion_queue_, server_completion_queue_,
new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)}); new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true});
break; break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
service_->RequestAsyncClientStreaming( service_->RequestAsyncClientStreaming(
method_index_, &server_context_, streaming_interface(), method_index_, &server_context_, streaming_interface(),
server_completion_queue_, server_completion_queue_, server_completion_queue_, server_completion_queue_,
new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)}); new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true});
break; break;
case ::grpc::internal::RpcMethod::NORMAL_RPC: case ::grpc::internal::RpcMethod::NORMAL_RPC:
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_index_, &server_context_, request_.get(), method_index_, &server_context_, request_.get(),
streaming_interface(), server_completion_queue_, streaming_interface(), server_completion_queue_,
server_completion_queue_, server_completion_queue_,
new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)}); new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true});
break; break;
default: default:
LOG(FATAL) << "RPC type not implemented."; LOG(FATAL) << "RPC type not implemented.";
@ -115,7 +115,8 @@ void Rpc::RequestStreamingReadIfNeeded() {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
SetRpcEventState(Event::READ, true); SetRpcEventState(Event::READ, true);
async_reader_interface()->Read( async_reader_interface()->Read(
request_.get(), new RpcEvent{Event::READ, weak_ptr_factory_(this)}); request_.get(),
new RpcEvent{Event::READ, weak_ptr_factory_(this), true});
break; break;
case ::grpc::internal::RpcMethod::NORMAL_RPC: case ::grpc::internal::RpcMethod::NORMAL_RPC:
// For NORMAL_RPC we don't have to do anything here, since gRPC // For NORMAL_RPC we don't have to do anything here, since gRPC
@ -130,86 +131,37 @@ void Rpc::RequestStreamingReadIfNeeded() {
} }
void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) { void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
switch (rpc_handler_info_.rpc_type) { EnqueueMessage(SendItem{std::move(message), ::grpc::Status::OK});
case ::grpc::internal::RpcMethod::BIDI_STREAMING: event_queue_->Push(
// For BIDI_STREAMING enqueue the message into the send queue and new RpcEvent{Event::WRITE_NEEDED, weak_ptr_factory_(this), true});
// start write operations if none are currently in flight.
send_queue_.emplace(SendItem{std::move(message), ::grpc::Status::OK});
PerformWriteIfNeeded();
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
SendFinish(std::move(message), ::grpc::Status::OK);
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
SendFinish(std::move(message), ::grpc::Status::OK);
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
}
void Rpc::SendFinish(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status) {
SetRpcEventState(Event::FINISH, true);
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
CHECK(!message);
server_async_reader_writer_->Finish(
status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this)});
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
response_ = std::move(message);
SendUnaryFinish(server_async_reader_.get(), status, response_.get(),
new RpcEvent{Event::FINISH, weak_ptr_factory_(this)});
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
response_ = std::move(message);
SendUnaryFinish(server_async_response_writer_.get(), status,
response_.get(),
new RpcEvent{Event::FINISH, weak_ptr_factory_(this)});
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
} }
void Rpc::Finish(::grpc::Status status) { void Rpc::Finish(::grpc::Status status) {
switch (rpc_handler_info_.rpc_type) { EnqueueMessage(SendItem{nullptr /* message */, status});
case ::grpc::internal::RpcMethod::BIDI_STREAMING: event_queue_->Push(
send_queue_.emplace(SendItem{nullptr /* msg */, status}); new RpcEvent{Event::WRITE_NEEDED, weak_ptr_factory_(this), true});
PerformWriteIfNeeded();
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
SendFinish(nullptr /* message */, status);
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
SendFinish(nullptr /* message */, status);
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
} }
void Rpc::PerformWriteIfNeeded() { void Rpc::HandleSendQueue() {
if (send_queue_.empty() || IsRpcEventPending(Event::WRITE)) { SendItem send_item;
{
cartographer::common::MutexLocker locker(&send_queue_lock_);
if (send_queue_.empty() || IsRpcEventPending(Event::WRITE) ||
IsRpcEventPending(Event::FINISH)) {
return; return;
} }
// Make sure not other send operations are in-flight. send_item = std::move(send_queue_.front());
CHECK(!IsRpcEventPending(Event::FINISH));
SendItem send_item = std::move(send_queue_.front());
send_queue_.pop(); send_queue_.pop();
response_ = std::move(send_item.msg);
if (response_) {
SetRpcEventState(Event::WRITE, true);
async_writer_interface()->Write(
*response_.get(), new RpcEvent{Event::WRITE, weak_ptr_factory_(this)});
} else {
CHECK(send_queue_.empty());
SendFinish(nullptr /* message */, send_item.status);
} }
if (!send_item.msg ||
rpc_handler_info_.rpc_type == ::grpc::internal::RpcMethod::NORMAL_RPC ||
rpc_handler_info_.rpc_type ==
::grpc::internal::RpcMethod::CLIENT_STREAMING) {
PerformFinish(std::move(send_item.msg), send_item.status);
return;
}
PerformWrite(std::move(send_item.msg), send_item.status);
} }
::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() { ::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() {
@ -259,20 +211,65 @@ Rpc::async_writer_interface() {
bool* Rpc::GetRpcEventState(Event event) { bool* Rpc::GetRpcEventState(Event event) {
switch (event) { switch (event) {
case Event::DONE:
return &done_event_pending_;
case Event::FINISH:
return &finish_event_pending_;
case Event::NEW_CONNECTION: case Event::NEW_CONNECTION:
return &new_connection_event_pending_; return &new_connection_event_pending_;
case Event::READ: case Event::READ:
return &read_event_pending_; return &read_event_pending_;
case Event::WRITE_NEEDED:
return &write_needed_event_pending_;
case Event::WRITE: case Event::WRITE:
return &write_event_pending_; return &write_event_pending_;
case Event::FINISH:
return &finish_event_pending_;
case Event::DONE:
return &done_event_pending_;
} }
LOG(FATAL) << "Never reached."; LOG(FATAL) << "Never reached.";
} }
void Rpc::EnqueueMessage(SendItem&& send_item) {
cartographer::common::MutexLocker locker(&send_queue_lock_);
send_queue_.emplace(std::move(send_item));
}
void Rpc::PerformFinish(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status) {
SetRpcEventState(Event::FINISH, true);
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
CHECK(!message);
server_async_reader_writer_->Finish(
status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true});
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
response_ = std::move(message);
SendUnaryFinish(
server_async_reader_.get(), status, response_.get(),
new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true});
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
response_ = std::move(message);
SendUnaryFinish(
server_async_response_writer_.get(), status, response_.get(),
new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true});
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
}
void Rpc::PerformWrite(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status) {
CHECK(message) << "PerformWrite must be called with a non-null message";
CHECK_NE(rpc_handler_info_.rpc_type, ::grpc::internal::RpcMethod::NORMAL_RPC);
CHECK_NE(rpc_handler_info_.rpc_type,
::grpc::internal::RpcMethod::CLIENT_STREAMING);
SetRpcEventState(Event::WRITE, true);
response_ = std::move(message);
async_writer_interface()->Write(
*response_, new RpcEvent{Event::WRITE, weak_ptr_factory_(this), true});
}
void Rpc::SetRpcEventState(Event event, bool pending) { void Rpc::SetRpcEventState(Event event, bool pending) {
*GetRpcEventState(event) = pending; *GetRpcEventState(event) = pending;
} }

View File

@ -36,12 +36,20 @@ namespace cartographer_grpc {
namespace framework { namespace framework {
class Service; class Service;
// TODO(cschuet): Add a unittest that tests the logic of this class.
class Rpc { class Rpc {
public: public:
struct RpcEvent; struct RpcEvent;
using EventQueue = cartographer::common::BlockingQueue<RpcEvent*>; using EventQueue = cartographer::common::BlockingQueue<RpcEvent*>;
using WeakPtrFactory = std::function<std::weak_ptr<Rpc>(Rpc*)>; using WeakPtrFactory = std::function<std::weak_ptr<Rpc>(Rpc*)>;
enum class Event { NEW_CONNECTION = 0, READ, WRITE, FINISH, DONE }; enum class Event {
NEW_CONNECTION = 0,
READ,
WRITE_NEEDED,
WRITE,
FINISH,
DONE
};
struct RpcEvent { struct RpcEvent {
const Event event; const Event event;
std::weak_ptr<Rpc> rpc; std::weak_ptr<Rpc> rpc;
@ -56,7 +64,7 @@ class Rpc {
void OnReadsDone(); void OnReadsDone();
void RequestNextMethodInvocation(); void RequestNextMethodInvocation();
void RequestStreamingReadIfNeeded(); void RequestStreamingReadIfNeeded();
void PerformWriteIfNeeded(); void HandleSendQueue();
void Write(std::unique_ptr<::google::protobuf::Message> message); void Write(std::unique_ptr<::google::protobuf::Message> message);
void Finish(::grpc::Status status); void Finish(::grpc::Status status);
Service* service() { return service_; } Service* service() { return service_; }
@ -76,9 +84,12 @@ class Rpc {
Rpc& operator=(const Rpc&) = delete; Rpc& operator=(const Rpc&) = delete;
void InitializeReadersAndWriters( void InitializeReadersAndWriters(
::grpc::internal::RpcMethod::RpcType rpc_type); ::grpc::internal::RpcMethod::RpcType rpc_type);
void SendFinish(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status);
bool* GetRpcEventState(Event event); bool* GetRpcEventState(Event event);
void EnqueueMessage(SendItem&& send_item);
void PerformFinish(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status);
void PerformWrite(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status);
::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>* ::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>*
async_reader_interface(); async_reader_interface();
@ -102,6 +113,7 @@ class Rpc {
// indicates that the read has completed and currently no read is in-flight. // indicates that the read has completed and currently no read is in-flight.
bool new_connection_event_pending_ = false; bool new_connection_event_pending_ = false;
bool read_event_pending_ = false; bool read_event_pending_ = false;
bool write_needed_event_pending_ = false;
bool write_event_pending_ = false; bool write_event_pending_ = false;
bool finish_event_pending_ = false; bool finish_event_pending_ = false;
bool done_event_pending_ = false; bool done_event_pending_ = false;
@ -120,6 +132,7 @@ class Rpc {
google::protobuf::Message>> google::protobuf::Message>>
server_async_reader_writer_; server_async_reader_writer_;
cartographer::common::Mutex send_queue_lock_;
std::queue<SendItem> send_queue_; std::queue<SendItem> send_queue_;
}; };

View File

@ -66,6 +66,7 @@ void Service::HandleEvent(Rpc::Event event, Rpc* rpc, bool ok) {
case Rpc::Event::READ: case Rpc::Event::READ:
HandleRead(rpc, ok); HandleRead(rpc, ok);
break; break;
case Rpc::Event::WRITE_NEEDED:
case Rpc::Event::WRITE: case Rpc::Event::WRITE:
HandleWrite(rpc, ok); HandleWrite(rpc, ok);
break; break;
@ -123,7 +124,7 @@ void Service::HandleWrite(Rpc* rpc, bool ok) {
} }
// Send the next message or potentially finish the connection. // Send the next message or potentially finish the connection.
rpc->PerformWriteIfNeeded(); rpc->HandleSendQueue();
RemoveIfNotPending(rpc); RemoveIfNotPending(rpc);
} }