Christoph Schütte 2017-12-18 13:36:44 +01:00 committed by Wally B. Feed
parent c6decd5b7b
commit ba7d375a25
3 changed files with 98 additions and 87 deletions

View File

@ -78,7 +78,7 @@ void Rpc::RequestNextMethodInvocation() {
// Ask gRPC to notify us when the connection terminates.
SetRpcEventState(Event::DONE, true);
server_context_.AsyncNotifyWhenDone(
new RpcEvent{Event::DONE, weak_ptr_factory_(this)});
new RpcEvent{Event::DONE, weak_ptr_factory_(this), true});
// Make sure after terminating the connection, gRPC notifies us with this
// event.
@ -88,20 +88,20 @@ void Rpc::RequestNextMethodInvocation() {
service_->RequestAsyncBidiStreaming(
method_index_, &server_context_, streaming_interface(),
server_completion_queue_, server_completion_queue_,
new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)});
new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true});
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
service_->RequestAsyncClientStreaming(
method_index_, &server_context_, streaming_interface(),
server_completion_queue_, server_completion_queue_,
new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)});
new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true});
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
service_->RequestAsyncUnary(
method_index_, &server_context_, request_.get(),
streaming_interface(), server_completion_queue_,
server_completion_queue_,
new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this)});
new RpcEvent{Event::NEW_CONNECTION, weak_ptr_factory_(this), true});
break;
default:
LOG(FATAL) << "RPC type not implemented.";
@ -115,7 +115,8 @@ void Rpc::RequestStreamingReadIfNeeded() {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
SetRpcEventState(Event::READ, true);
async_reader_interface()->Read(
request_.get(), new RpcEvent{Event::READ, weak_ptr_factory_(this)});
request_.get(),
new RpcEvent{Event::READ, weak_ptr_factory_(this), true});
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
// For NORMAL_RPC we don't have to do anything here, since gRPC
@ -130,86 +131,37 @@ void Rpc::RequestStreamingReadIfNeeded() {
}
void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
// For BIDI_STREAMING enqueue the message into the send queue and
// start write operations if none are currently in flight.
send_queue_.emplace(SendItem{std::move(message), ::grpc::Status::OK});
PerformWriteIfNeeded();
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
SendFinish(std::move(message), ::grpc::Status::OK);
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
SendFinish(std::move(message), ::grpc::Status::OK);
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
}
void Rpc::SendFinish(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status) {
SetRpcEventState(Event::FINISH, true);
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
CHECK(!message);
server_async_reader_writer_->Finish(
status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this)});
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
response_ = std::move(message);
SendUnaryFinish(server_async_reader_.get(), status, response_.get(),
new RpcEvent{Event::FINISH, weak_ptr_factory_(this)});
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
response_ = std::move(message);
SendUnaryFinish(server_async_response_writer_.get(), status,
response_.get(),
new RpcEvent{Event::FINISH, weak_ptr_factory_(this)});
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
EnqueueMessage(SendItem{std::move(message), ::grpc::Status::OK});
event_queue_->Push(
new RpcEvent{Event::WRITE_NEEDED, weak_ptr_factory_(this), true});
}
void Rpc::Finish(::grpc::Status status) {
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
send_queue_.emplace(SendItem{nullptr /* msg */, status});
PerformWriteIfNeeded();
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
SendFinish(nullptr /* message */, status);
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
SendFinish(nullptr /* message */, status);
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
EnqueueMessage(SendItem{nullptr /* message */, status});
event_queue_->Push(
new RpcEvent{Event::WRITE_NEEDED, weak_ptr_factory_(this), true});
}
void Rpc::PerformWriteIfNeeded() {
if (send_queue_.empty() || IsRpcEventPending(Event::WRITE)) {
void Rpc::HandleSendQueue() {
SendItem send_item;
{
cartographer::common::MutexLocker locker(&send_queue_lock_);
if (send_queue_.empty() || IsRpcEventPending(Event::WRITE) ||
IsRpcEventPending(Event::FINISH)) {
return;
}
send_item = std::move(send_queue_.front());
send_queue_.pop();
}
if (!send_item.msg ||
rpc_handler_info_.rpc_type == ::grpc::internal::RpcMethod::NORMAL_RPC ||
rpc_handler_info_.rpc_type ==
::grpc::internal::RpcMethod::CLIENT_STREAMING) {
PerformFinish(std::move(send_item.msg), send_item.status);
return;
}
// Make sure not other send operations are in-flight.
CHECK(!IsRpcEventPending(Event::FINISH));
SendItem send_item = std::move(send_queue_.front());
send_queue_.pop();
response_ = std::move(send_item.msg);
if (response_) {
SetRpcEventState(Event::WRITE, true);
async_writer_interface()->Write(
*response_.get(), new RpcEvent{Event::WRITE, weak_ptr_factory_(this)});
} else {
CHECK(send_queue_.empty());
SendFinish(nullptr /* message */, send_item.status);
}
PerformWrite(std::move(send_item.msg), send_item.status);
}
::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() {
@ -259,20 +211,65 @@ Rpc::async_writer_interface() {
bool* Rpc::GetRpcEventState(Event event) {
switch (event) {
case Event::DONE:
return &done_event_pending_;
case Event::FINISH:
return &finish_event_pending_;
case Event::NEW_CONNECTION:
return &new_connection_event_pending_;
case Event::READ:
return &read_event_pending_;
case Event::WRITE_NEEDED:
return &write_needed_event_pending_;
case Event::WRITE:
return &write_event_pending_;
case Event::FINISH:
return &finish_event_pending_;
case Event::DONE:
return &done_event_pending_;
}
LOG(FATAL) << "Never reached.";
}
void Rpc::EnqueueMessage(SendItem&& send_item) {
cartographer::common::MutexLocker locker(&send_queue_lock_);
send_queue_.emplace(std::move(send_item));
}
void Rpc::PerformFinish(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status) {
SetRpcEventState(Event::FINISH, true);
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
CHECK(!message);
server_async_reader_writer_->Finish(
status, new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true});
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
response_ = std::move(message);
SendUnaryFinish(
server_async_reader_.get(), status, response_.get(),
new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true});
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
response_ = std::move(message);
SendUnaryFinish(
server_async_response_writer_.get(), status, response_.get(),
new RpcEvent{Event::FINISH, weak_ptr_factory_(this), true});
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
}
void Rpc::PerformWrite(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status) {
CHECK(message) << "PerformWrite must be called with a non-null message";
CHECK_NE(rpc_handler_info_.rpc_type, ::grpc::internal::RpcMethod::NORMAL_RPC);
CHECK_NE(rpc_handler_info_.rpc_type,
::grpc::internal::RpcMethod::CLIENT_STREAMING);
SetRpcEventState(Event::WRITE, true);
response_ = std::move(message);
async_writer_interface()->Write(
*response_, new RpcEvent{Event::WRITE, weak_ptr_factory_(this), true});
}
void Rpc::SetRpcEventState(Event event, bool pending) {
*GetRpcEventState(event) = pending;
}

View File

@ -36,12 +36,20 @@ namespace cartographer_grpc {
namespace framework {
class Service;
// TODO(cschuet): Add a unittest that tests the logic of this class.
class Rpc {
public:
struct RpcEvent;
using EventQueue = cartographer::common::BlockingQueue<RpcEvent*>;
using WeakPtrFactory = std::function<std::weak_ptr<Rpc>(Rpc*)>;
enum class Event { NEW_CONNECTION = 0, READ, WRITE, FINISH, DONE };
enum class Event {
NEW_CONNECTION = 0,
READ,
WRITE_NEEDED,
WRITE,
FINISH,
DONE
};
struct RpcEvent {
const Event event;
std::weak_ptr<Rpc> rpc;
@ -56,7 +64,7 @@ class Rpc {
void OnReadsDone();
void RequestNextMethodInvocation();
void RequestStreamingReadIfNeeded();
void PerformWriteIfNeeded();
void HandleSendQueue();
void Write(std::unique_ptr<::google::protobuf::Message> message);
void Finish(::grpc::Status status);
Service* service() { return service_; }
@ -76,9 +84,12 @@ class Rpc {
Rpc& operator=(const Rpc&) = delete;
void InitializeReadersAndWriters(
::grpc::internal::RpcMethod::RpcType rpc_type);
void SendFinish(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status);
bool* GetRpcEventState(Event event);
void EnqueueMessage(SendItem&& send_item);
void PerformFinish(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status);
void PerformWrite(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status);
::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>*
async_reader_interface();
@ -102,6 +113,7 @@ class Rpc {
// indicates that the read has completed and currently no read is in-flight.
bool new_connection_event_pending_ = false;
bool read_event_pending_ = false;
bool write_needed_event_pending_ = false;
bool write_event_pending_ = false;
bool finish_event_pending_ = false;
bool done_event_pending_ = false;
@ -120,6 +132,7 @@ class Rpc {
google::protobuf::Message>>
server_async_reader_writer_;
cartographer::common::Mutex send_queue_lock_;
std::queue<SendItem> send_queue_;
};

View File

@ -66,6 +66,7 @@ void Service::HandleEvent(Rpc::Event event, Rpc* rpc, bool ok) {
case Rpc::Event::READ:
HandleRead(rpc, ok);
break;
case Rpc::Event::WRITE_NEEDED:
case Rpc::Event::WRITE:
HandleWrite(rpc, ok);
break;
@ -123,7 +124,7 @@ void Service::HandleWrite(Rpc* rpc, bool ok) {
}
// Send the next message or potentially finish the connection.
rpc->PerformWriteIfNeeded();
rpc->HandleSendQueue();
RemoveIfNotPending(rpc);
}