Christoph Schütte 2017-12-04 15:28:19 +01:00 committed by GitHub
parent 1ff8243802
commit 32a8364b98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 200 additions and 13 deletions

View File

@ -36,4 +36,5 @@ message GetSquareResponse {
service Math { service Math {
rpc GetSum(stream GetSumRequest) returns (GetSumResponse); rpc GetSum(stream GetSumRequest) returns (GetSumResponse);
rpc GetSquare(GetSquareRequest) returns (GetSquareResponse); rpc GetSquare(GetSquareRequest) returns (GetSquareResponse);
rpc GetRunningSum(stream GetSumRequest) returns (stream GetSumResponse);
} }

View File

@ -22,6 +22,23 @@
namespace cartographer_grpc { namespace cartographer_grpc {
namespace framework { namespace framework {
namespace {
// Finishes the gRPC for non-streaming response RPCs, i.e. NORMAL_RPC and
// CLIENT_STREAMING. If no 'msg' is passed, we signal an error to the client as
// the server is not honoring the gRPC call signature.
template <typename ReaderWriter>
void SendUnaryFinish(ReaderWriter* reader_writer, ::grpc::Status status,
const google::protobuf::Message* msg,
Rpc::RpcEvent* rpc_event) {
if (msg) {
reader_writer->Finish(*msg, status, rpc_event);
} else {
reader_writer->FinishWithError(status, rpc_event);
}
}
} // namespace
Rpc::Rpc(int method_index, Rpc::Rpc(int method_index,
::grpc::ServerCompletionQueue* server_completion_queue, ::grpc::ServerCompletionQueue* server_completion_queue,
@ -60,10 +77,20 @@ void Rpc::OnRequest() { handler_->OnRequestInternal(request_.get()); }
void Rpc::OnReadsDone() { handler_->OnReadsDone(); } void Rpc::OnReadsDone() { handler_->OnReadsDone(); }
void Rpc::RequestNextMethodInvocation() { void Rpc::RequestNextMethodInvocation() {
// Ask gRPC to notify us when the connection terminates.
done_event_.pending = true; done_event_.pending = true;
new_connection_event_.pending = true;
server_context_.AsyncNotifyWhenDone(&done_event_); server_context_.AsyncNotifyWhenDone(&done_event_);
// Make sure after terminating the connection, gRPC notifies us with this
// event.
new_connection_event_.pending = true;
switch (rpc_handler_info_.rpc_type) { switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
service_->RequestAsyncBidiStreaming(
method_index_, &server_context_, streaming_interface(),
server_completion_queue_, server_completion_queue_,
&new_connection_event_);
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
service_->RequestAsyncClientStreaming( service_->RequestAsyncClientStreaming(
method_index_, &server_context_, streaming_interface(), method_index_, &server_context_, streaming_interface(),
@ -84,6 +111,7 @@ void Rpc::RequestNextMethodInvocation() {
void Rpc::RequestStreamingReadIfNeeded() { void Rpc::RequestStreamingReadIfNeeded() {
// For request-streaming RPCs ask the client to start sending requests. // For request-streaming RPCs ask the client to start sending requests.
switch (rpc_handler_info_.rpc_type) { switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
read_event_.pending = true; read_event_.pending = true;
async_reader_interface()->Read(request_.get(), &read_event_); async_reader_interface()->Read(request_.get(), &read_event_);
@ -101,25 +129,89 @@ void Rpc::RequestStreamingReadIfNeeded() {
} }
void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) { void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
response_ = std::move(message);
switch (rpc_handler_info_.rpc_type) { switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
// For BIDI_STREAMING enqueue the message into the send queue and
// start write operations if none are currently in flight.
send_queue_.emplace(SendItem{std::move(message), ::grpc::Status::OK});
PerformWriteIfNeeded();
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK, SendFinish(std::move(message), ::grpc::Status::OK);
&finish_event_);
finish_event_.pending = true;
break; break;
case ::grpc::internal::RpcMethod::NORMAL_RPC: case ::grpc::internal::RpcMethod::NORMAL_RPC:
server_async_response_writer_->Finish(*response_.get(), SendFinish(std::move(message), ::grpc::Status::OK);
::grpc::Status::OK, &finish_event_);
finish_event_.pending = true;
break; break;
default: default:
LOG(FATAL) << "RPC type not implemented."; LOG(FATAL) << "RPC type not implemented.";
} }
} }
void Rpc::SendFinish(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status) {
finish_event_.pending = true;
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
CHECK(!message);
server_async_reader_writer_->Finish(status, &finish_event_);
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
response_ = std::move(message);
SendUnaryFinish(server_async_reader_.get(), status, response_.get(),
&finish_event_);
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
response_ = std::move(message);
SendUnaryFinish(server_async_response_writer_.get(), status,
response_.get(), &finish_event_);
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
}
void Rpc::Finish(::grpc::Status status) {
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
send_queue_.emplace(SendItem{nullptr /* msg */, status});
PerformWriteIfNeeded();
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
SendFinish(nullptr /* message */, status);
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
SendFinish(nullptr /* message */, status);
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
}
void Rpc::PerformWriteIfNeeded() {
if (send_queue_.empty() || write_event_.pending) {
return;
}
// Make sure not other send operations are in-flight.
CHECK(!finish_event_.pending);
SendItem send_item = std::move(send_queue_.front());
send_queue_.pop();
response_ = std::move(send_item.msg);
if (response_) {
write_event_.pending = true;
async_writer_interface()->Write(*response_.get(), &write_event_);
} else {
CHECK(send_queue_.empty());
SendFinish(nullptr /* message */, send_item.status);
}
}
::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() { ::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() {
switch (rpc_handler_info_.rpc_type) { switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
return server_async_reader_writer_.get();
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
return server_async_reader_.get(); return server_async_reader_.get();
case ::grpc::internal::RpcMethod::NORMAL_RPC: case ::grpc::internal::RpcMethod::NORMAL_RPC:
@ -133,10 +225,28 @@ void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>* ::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>*
Rpc::async_reader_interface() { Rpc::async_reader_interface() {
switch (rpc_handler_info_.rpc_type) { switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
return server_async_reader_writer_.get();
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
return server_async_reader_.get(); return server_async_reader_.get();
case ::grpc::internal::RpcMethod::NORMAL_RPC: case ::grpc::internal::RpcMethod::NORMAL_RPC:
LOG(FATAL) << "For NORMAL_RPC no streaming interface exists."; LOG(FATAL) << "For NORMAL_RPC no streaming reader interface exists.";
default:
LOG(FATAL) << "RPC type not implemented.";
}
LOG(FATAL) << "Never reached.";
}
::grpc::internal::AsyncWriterInterface<::google::protobuf::Message>*
Rpc::async_writer_interface() {
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
return server_async_reader_writer_.get();
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
case ::grpc::internal::RpcMethod::NORMAL_RPC:
LOG(FATAL) << "For NORMAL_RPC and CLIENT_STREAMING no streaming writer "
"interface exists.";
break;
default: default:
LOG(FATAL) << "RPC type not implemented."; LOG(FATAL) << "RPC type not implemented.";
} }
@ -145,16 +255,16 @@ Rpc::async_reader_interface() {
Rpc::RpcEvent* Rpc::GetRpcEvent(Event event) { Rpc::RpcEvent* Rpc::GetRpcEvent(Event event) {
switch (event) { switch (event) {
case Event::DONE:
return &done_event_;
case Event::FINISH:
return &finish_event_;
case Event::NEW_CONNECTION: case Event::NEW_CONNECTION:
return &new_connection_event_; return &new_connection_event_;
case Event::READ: case Event::READ:
return &read_event_; return &read_event_;
case Event::WRITE: case Event::WRITE:
return &write_event_; return &write_event_;
case Event::FINISH:
return &finish_event_;
case Event::DONE:
return &done_event_;
} }
LOG(FATAL) << "Never reached."; LOG(FATAL) << "Never reached.";
} }
@ -164,6 +274,12 @@ ActiveRpcs::ActiveRpcs() : lock_() {}
void Rpc::InitializeReadersAndWriters( void Rpc::InitializeReadersAndWriters(
::grpc::internal::RpcMethod::RpcType rpc_type) { ::grpc::internal::RpcMethod::RpcType rpc_type) {
switch (rpc_type) { switch (rpc_type) {
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
server_async_reader_writer_ =
cartographer::common::make_unique<::grpc::ServerAsyncReaderWriter<
google::protobuf::Message, google::protobuf::Message>>(
&server_context_);
break;
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
server_async_reader_ = server_async_reader_ =
cartographer::common::make_unique<::grpc::ServerAsyncReader< cartographer::common::make_unique<::grpc::ServerAsyncReader<

View File

@ -18,6 +18,7 @@
#define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_H #define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_H
#include <memory> #include <memory>
#include <queue>
#include <unordered_set> #include <unordered_set>
#include "cartographer/common/mutex.h" #include "cartographer/common/mutex.h"
@ -55,18 +56,30 @@ class Rpc {
void OnReadsDone(); void OnReadsDone();
void RequestNextMethodInvocation(); void RequestNextMethodInvocation();
void RequestStreamingReadIfNeeded(); void RequestStreamingReadIfNeeded();
void PerformWriteIfNeeded();
void Write(std::unique_ptr<::google::protobuf::Message> message); void Write(std::unique_ptr<::google::protobuf::Message> message);
void Finish(::grpc::Status status);
Service* service() { return service_; } Service* service() { return service_; }
RpcEvent* GetRpcEvent(Event event); RpcEvent* GetRpcEvent(Event event);
private: private:
struct SendItem {
std::unique_ptr<google::protobuf::Message> msg;
::grpc::Status status;
};
Rpc(const Rpc&) = delete; Rpc(const Rpc&) = delete;
Rpc& operator=(const Rpc&) = delete; Rpc& operator=(const Rpc&) = delete;
void InitializeReadersAndWriters( void InitializeReadersAndWriters(
::grpc::internal::RpcMethod::RpcType rpc_type); ::grpc::internal::RpcMethod::RpcType rpc_type);
void SendFinish(std::unique_ptr<::google::protobuf::Message> message,
::grpc::Status status);
::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>* ::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>*
async_reader_interface(); async_reader_interface();
::grpc::internal::AsyncWriterInterface<::google::protobuf::Message>*
async_writer_interface();
::grpc::internal::ServerAsyncStreamingInterface* streaming_interface(); ::grpc::internal::ServerAsyncStreamingInterface* streaming_interface();
int method_index_; int method_index_;
@ -92,6 +105,11 @@ class Rpc {
std::unique_ptr<::grpc::ServerAsyncReader<google::protobuf::Message, std::unique_ptr<::grpc::ServerAsyncReader<google::protobuf::Message,
google::protobuf::Message>> google::protobuf::Message>>
server_async_reader_; server_async_reader_;
std::unique_ptr<::grpc::ServerAsyncReaderWriter<google::protobuf::Message,
google::protobuf::Message>>
server_async_reader_writer_;
std::queue<SendItem> send_queue_;
}; };
// This class keeps track of all in-flight RPCs for a 'Service'. Make sure that // This class keeps track of all in-flight RPCs for a 'Service'. Make sure that

View File

@ -45,6 +45,7 @@ class RpcHandler : public RpcHandlerInterface {
OnRequest(static_cast<const RequestType&>(*request)); OnRequest(static_cast<const RequestType&>(*request));
} }
virtual void OnRequest(const RequestType& request) = 0; virtual void OnRequest(const RequestType& request) = 0;
void Finish(::grpc::Status status) { rpc_->Finish(status); }
void Send(std::unique_ptr<ResponseType> response) { void Send(std::unique_ptr<ResponseType> response) {
rpc_->Write(std::move(response)); rpc_->Write(std::move(response));
} }

View File

@ -51,6 +51,29 @@ class GetSumHandler
int sum_ = 0; int sum_ = 0;
}; };
class GetRunningSumHandler
: public RpcHandler<Stream<proto::GetSumRequest>, Stream<proto::GetSumResponse>> {
public:
void OnRequest(const proto::GetSumRequest& request) override {
sum_ += request.input();
// Respond twice to demonstrate bidirectional streaming.
auto response = cartographer::common::make_unique<proto::GetSumResponse>();
response->set_output(sum_);
Send(std::move(response));
response = cartographer::common::make_unique<proto::GetSumResponse>();
response->set_output(sum_);
Send(std::move(response));
}
void OnReadsDone() override {
Finish(::grpc::Status::OK);
}
private:
int sum_ = 0;
};
class GetSquareHandler class GetSquareHandler
: public RpcHandler<proto::GetSquareRequest, proto::GetSquareResponse> { : public RpcHandler<proto::GetSquareRequest, proto::GetSquareResponse> {
void OnRequest(const proto::GetSquareRequest& request) override { void OnRequest(const proto::GetSquareRequest& request) override {
@ -75,6 +98,7 @@ class ServerTest : public ::testing::Test {
server_builder.SetNumberOfThreads(kNumThreads); server_builder.SetNumberOfThreads(kNumThreads);
server_builder.RegisterHandler<GetSumHandler, proto::Math>("GetSum"); server_builder.RegisterHandler<GetSumHandler, proto::Math>("GetSum");
server_builder.RegisterHandler<GetSquareHandler, proto::Math>("GetSquare"); server_builder.RegisterHandler<GetSquareHandler, proto::Math>("GetSquare");
server_builder.RegisterHandler<GetRunningSumHandler, proto::Math>("GetRunningSum");
server_ = server_builder.Build(); server_ = server_builder.Build();
client_channel_ = client_channel_ =
@ -127,6 +151,28 @@ TEST_F(ServerTest, ProcessUnaryRpcTest) {
server_->Shutdown(); server_->Shutdown();
} }
TEST_F(ServerTest, ProcessBidiStreamingRpcTest) {
server_->Start();
auto reader_writer = stub_->GetRunningSum(&client_context_);
for (int i = 0; i < 3; ++i) {
proto::GetSumRequest request;
request.set_input(i);
EXPECT_TRUE(reader_writer->Write(request));
}
reader_writer->WritesDone();
proto::GetSumResponse response;
std::list<int> expected_responses = {0, 0, 1, 1, 3, 3};
while (reader_writer->Read(&response)) {
EXPECT_EQ(expected_responses.front(), response.output());
expected_responses.pop_front();
}
EXPECT_TRUE(expected_responses.empty());
server_->Shutdown();
}
} // namespace } // namespace
} // namespace framework } // namespace framework
} // namespace cartographer_grpc } // namespace cartographer_grpc

View File

@ -105,6 +105,8 @@ void Service::HandleRead(Rpc* rpc, bool ok) {
// Reads completed. // Reads completed.
rpc->OnReadsDone(); rpc->OnReadsDone();
RemoveIfNotPending(rpc);
} }
void Service::HandleWrite(Rpc* rpc, bool ok) { void Service::HandleWrite(Rpc* rpc, bool ok) {
@ -112,6 +114,9 @@ void Service::HandleWrite(Rpc* rpc, bool ok) {
LOG(ERROR) << "Write failed"; LOG(ERROR) << "Write failed";
} }
// Send the next message or potentially finish the connection.
rpc->PerformWriteIfNeeded();
RemoveIfNotPending(rpc); RemoveIfNotPending(rpc);
} }