Implement bi-directional streaming RPCs. (#720)
[RFC=0002](https://github.com/googlecartographer/rfcs/blob/master/text/0002-cloud-based-mapping-1.md)master
parent
1ff8243802
commit
32a8364b98
|
@ -36,4 +36,5 @@ message GetSquareResponse {
|
|||
service Math {
|
||||
rpc GetSum(stream GetSumRequest) returns (GetSumResponse);
|
||||
rpc GetSquare(GetSquareRequest) returns (GetSquareResponse);
|
||||
rpc GetRunningSum(stream GetSumRequest) returns (stream GetSumResponse);
|
||||
}
|
||||
|
|
|
@ -22,6 +22,23 @@
|
|||
|
||||
namespace cartographer_grpc {
|
||||
namespace framework {
|
||||
namespace {
|
||||
|
||||
// Finishes the gRPC for non-streaming response RPCs, i.e. NORMAL_RPC and
|
||||
// CLIENT_STREAMING. If no 'msg' is passed, we signal an error to the client as
|
||||
// the server is not honoring the gRPC call signature.
|
||||
template <typename ReaderWriter>
|
||||
void SendUnaryFinish(ReaderWriter* reader_writer, ::grpc::Status status,
|
||||
const google::protobuf::Message* msg,
|
||||
Rpc::RpcEvent* rpc_event) {
|
||||
if (msg) {
|
||||
reader_writer->Finish(*msg, status, rpc_event);
|
||||
} else {
|
||||
reader_writer->FinishWithError(status, rpc_event);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Rpc::Rpc(int method_index,
|
||||
::grpc::ServerCompletionQueue* server_completion_queue,
|
||||
|
@ -60,10 +77,20 @@ void Rpc::OnRequest() { handler_->OnRequestInternal(request_.get()); }
|
|||
void Rpc::OnReadsDone() { handler_->OnReadsDone(); }
|
||||
|
||||
void Rpc::RequestNextMethodInvocation() {
|
||||
// Ask gRPC to notify us when the connection terminates.
|
||||
done_event_.pending = true;
|
||||
new_connection_event_.pending = true;
|
||||
server_context_.AsyncNotifyWhenDone(&done_event_);
|
||||
|
||||
// Make sure after terminating the connection, gRPC notifies us with this
|
||||
// event.
|
||||
new_connection_event_.pending = true;
|
||||
switch (rpc_handler_info_.rpc_type) {
|
||||
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
|
||||
service_->RequestAsyncBidiStreaming(
|
||||
method_index_, &server_context_, streaming_interface(),
|
||||
server_completion_queue_, server_completion_queue_,
|
||||
&new_connection_event_);
|
||||
break;
|
||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||
service_->RequestAsyncClientStreaming(
|
||||
method_index_, &server_context_, streaming_interface(),
|
||||
|
@ -84,6 +111,7 @@ void Rpc::RequestNextMethodInvocation() {
|
|||
void Rpc::RequestStreamingReadIfNeeded() {
|
||||
// For request-streaming RPCs ask the client to start sending requests.
|
||||
switch (rpc_handler_info_.rpc_type) {
|
||||
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
|
||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||
read_event_.pending = true;
|
||||
async_reader_interface()->Read(request_.get(), &read_event_);
|
||||
|
@ -101,25 +129,89 @@ void Rpc::RequestStreamingReadIfNeeded() {
|
|||
}
|
||||
|
||||
void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
|
||||
response_ = std::move(message);
|
||||
switch (rpc_handler_info_.rpc_type) {
|
||||
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
|
||||
// For BIDI_STREAMING enqueue the message into the send queue and
|
||||
// start write operations if none are currently in flight.
|
||||
send_queue_.emplace(SendItem{std::move(message), ::grpc::Status::OK});
|
||||
PerformWriteIfNeeded();
|
||||
break;
|
||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||
server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK,
|
||||
&finish_event_);
|
||||
finish_event_.pending = true;
|
||||
SendFinish(std::move(message), ::grpc::Status::OK);
|
||||
break;
|
||||
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||
server_async_response_writer_->Finish(*response_.get(),
|
||||
::grpc::Status::OK, &finish_event_);
|
||||
finish_event_.pending = true;
|
||||
SendFinish(std::move(message), ::grpc::Status::OK);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "RPC type not implemented.";
|
||||
}
|
||||
}
|
||||
|
||||
void Rpc::SendFinish(std::unique_ptr<::google::protobuf::Message> message,
|
||||
::grpc::Status status) {
|
||||
finish_event_.pending = true;
|
||||
switch (rpc_handler_info_.rpc_type) {
|
||||
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
|
||||
CHECK(!message);
|
||||
server_async_reader_writer_->Finish(status, &finish_event_);
|
||||
break;
|
||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||
response_ = std::move(message);
|
||||
SendUnaryFinish(server_async_reader_.get(), status, response_.get(),
|
||||
&finish_event_);
|
||||
break;
|
||||
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||
response_ = std::move(message);
|
||||
SendUnaryFinish(server_async_response_writer_.get(), status,
|
||||
response_.get(), &finish_event_);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "RPC type not implemented.";
|
||||
}
|
||||
}
|
||||
|
||||
void Rpc::Finish(::grpc::Status status) {
|
||||
switch (rpc_handler_info_.rpc_type) {
|
||||
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
|
||||
send_queue_.emplace(SendItem{nullptr /* msg */, status});
|
||||
PerformWriteIfNeeded();
|
||||
break;
|
||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||
SendFinish(nullptr /* message */, status);
|
||||
break;
|
||||
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||
SendFinish(nullptr /* message */, status);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "RPC type not implemented.";
|
||||
}
|
||||
}
|
||||
|
||||
void Rpc::PerformWriteIfNeeded() {
|
||||
if (send_queue_.empty() || write_event_.pending) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Make sure not other send operations are in-flight.
|
||||
CHECK(!finish_event_.pending);
|
||||
|
||||
SendItem send_item = std::move(send_queue_.front());
|
||||
send_queue_.pop();
|
||||
response_ = std::move(send_item.msg);
|
||||
|
||||
if (response_) {
|
||||
write_event_.pending = true;
|
||||
async_writer_interface()->Write(*response_.get(), &write_event_);
|
||||
} else {
|
||||
CHECK(send_queue_.empty());
|
||||
SendFinish(nullptr /* message */, send_item.status);
|
||||
}
|
||||
}
|
||||
|
||||
::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() {
|
||||
switch (rpc_handler_info_.rpc_type) {
|
||||
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
|
||||
return server_async_reader_writer_.get();
|
||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||
return server_async_reader_.get();
|
||||
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||
|
@ -133,10 +225,28 @@ void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
|
|||
::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>*
|
||||
Rpc::async_reader_interface() {
|
||||
switch (rpc_handler_info_.rpc_type) {
|
||||
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
|
||||
return server_async_reader_writer_.get();
|
||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||
return server_async_reader_.get();
|
||||
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||
LOG(FATAL) << "For NORMAL_RPC no streaming interface exists.";
|
||||
LOG(FATAL) << "For NORMAL_RPC no streaming reader interface exists.";
|
||||
default:
|
||||
LOG(FATAL) << "RPC type not implemented.";
|
||||
}
|
||||
LOG(FATAL) << "Never reached.";
|
||||
}
|
||||
|
||||
::grpc::internal::AsyncWriterInterface<::google::protobuf::Message>*
|
||||
Rpc::async_writer_interface() {
|
||||
switch (rpc_handler_info_.rpc_type) {
|
||||
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
|
||||
return server_async_reader_writer_.get();
|
||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||
LOG(FATAL) << "For NORMAL_RPC and CLIENT_STREAMING no streaming writer "
|
||||
"interface exists.";
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "RPC type not implemented.";
|
||||
}
|
||||
|
@ -145,16 +255,16 @@ Rpc::async_reader_interface() {
|
|||
|
||||
Rpc::RpcEvent* Rpc::GetRpcEvent(Event event) {
|
||||
switch (event) {
|
||||
case Event::DONE:
|
||||
return &done_event_;
|
||||
case Event::FINISH:
|
||||
return &finish_event_;
|
||||
case Event::NEW_CONNECTION:
|
||||
return &new_connection_event_;
|
||||
case Event::READ:
|
||||
return &read_event_;
|
||||
case Event::WRITE:
|
||||
return &write_event_;
|
||||
case Event::FINISH:
|
||||
return &finish_event_;
|
||||
case Event::DONE:
|
||||
return &done_event_;
|
||||
}
|
||||
LOG(FATAL) << "Never reached.";
|
||||
}
|
||||
|
@ -164,6 +274,12 @@ ActiveRpcs::ActiveRpcs() : lock_() {}
|
|||
void Rpc::InitializeReadersAndWriters(
|
||||
::grpc::internal::RpcMethod::RpcType rpc_type) {
|
||||
switch (rpc_type) {
|
||||
case ::grpc::internal::RpcMethod::BIDI_STREAMING:
|
||||
server_async_reader_writer_ =
|
||||
cartographer::common::make_unique<::grpc::ServerAsyncReaderWriter<
|
||||
google::protobuf::Message, google::protobuf::Message>>(
|
||||
&server_context_);
|
||||
break;
|
||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||
server_async_reader_ =
|
||||
cartographer::common::make_unique<::grpc::ServerAsyncReader<
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_H
|
||||
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "cartographer/common/mutex.h"
|
||||
|
@ -55,18 +56,30 @@ class Rpc {
|
|||
void OnReadsDone();
|
||||
void RequestNextMethodInvocation();
|
||||
void RequestStreamingReadIfNeeded();
|
||||
void PerformWriteIfNeeded();
|
||||
void Write(std::unique_ptr<::google::protobuf::Message> message);
|
||||
void Finish(::grpc::Status status);
|
||||
Service* service() { return service_; }
|
||||
RpcEvent* GetRpcEvent(Event event);
|
||||
|
||||
private:
|
||||
struct SendItem {
|
||||
std::unique_ptr<google::protobuf::Message> msg;
|
||||
::grpc::Status status;
|
||||
};
|
||||
|
||||
Rpc(const Rpc&) = delete;
|
||||
Rpc& operator=(const Rpc&) = delete;
|
||||
void InitializeReadersAndWriters(
|
||||
::grpc::internal::RpcMethod::RpcType rpc_type);
|
||||
void SendFinish(std::unique_ptr<::google::protobuf::Message> message,
|
||||
::grpc::Status status);
|
||||
|
||||
::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>*
|
||||
async_reader_interface();
|
||||
::grpc::internal::AsyncWriterInterface<::google::protobuf::Message>*
|
||||
async_writer_interface();
|
||||
|
||||
::grpc::internal::ServerAsyncStreamingInterface* streaming_interface();
|
||||
|
||||
int method_index_;
|
||||
|
@ -92,6 +105,11 @@ class Rpc {
|
|||
std::unique_ptr<::grpc::ServerAsyncReader<google::protobuf::Message,
|
||||
google::protobuf::Message>>
|
||||
server_async_reader_;
|
||||
std::unique_ptr<::grpc::ServerAsyncReaderWriter<google::protobuf::Message,
|
||||
google::protobuf::Message>>
|
||||
server_async_reader_writer_;
|
||||
|
||||
std::queue<SendItem> send_queue_;
|
||||
};
|
||||
|
||||
// This class keeps track of all in-flight RPCs for a 'Service'. Make sure that
|
||||
|
|
|
@ -45,6 +45,7 @@ class RpcHandler : public RpcHandlerInterface {
|
|||
OnRequest(static_cast<const RequestType&>(*request));
|
||||
}
|
||||
virtual void OnRequest(const RequestType& request) = 0;
|
||||
void Finish(::grpc::Status status) { rpc_->Finish(status); }
|
||||
void Send(std::unique_ptr<ResponseType> response) {
|
||||
rpc_->Write(std::move(response));
|
||||
}
|
||||
|
|
|
@ -51,6 +51,29 @@ class GetSumHandler
|
|||
int sum_ = 0;
|
||||
};
|
||||
|
||||
class GetRunningSumHandler
|
||||
: public RpcHandler<Stream<proto::GetSumRequest>, Stream<proto::GetSumResponse>> {
|
||||
public:
|
||||
void OnRequest(const proto::GetSumRequest& request) override {
|
||||
sum_ += request.input();
|
||||
|
||||
// Respond twice to demonstrate bidirectional streaming.
|
||||
auto response = cartographer::common::make_unique<proto::GetSumResponse>();
|
||||
response->set_output(sum_);
|
||||
Send(std::move(response));
|
||||
response = cartographer::common::make_unique<proto::GetSumResponse>();
|
||||
response->set_output(sum_);
|
||||
Send(std::move(response));
|
||||
}
|
||||
|
||||
void OnReadsDone() override {
|
||||
Finish(::grpc::Status::OK);
|
||||
}
|
||||
|
||||
private:
|
||||
int sum_ = 0;
|
||||
};
|
||||
|
||||
class GetSquareHandler
|
||||
: public RpcHandler<proto::GetSquareRequest, proto::GetSquareResponse> {
|
||||
void OnRequest(const proto::GetSquareRequest& request) override {
|
||||
|
@ -75,6 +98,7 @@ class ServerTest : public ::testing::Test {
|
|||
server_builder.SetNumberOfThreads(kNumThreads);
|
||||
server_builder.RegisterHandler<GetSumHandler, proto::Math>("GetSum");
|
||||
server_builder.RegisterHandler<GetSquareHandler, proto::Math>("GetSquare");
|
||||
server_builder.RegisterHandler<GetRunningSumHandler, proto::Math>("GetRunningSum");
|
||||
server_ = server_builder.Build();
|
||||
|
||||
client_channel_ =
|
||||
|
@ -127,6 +151,28 @@ TEST_F(ServerTest, ProcessUnaryRpcTest) {
|
|||
server_->Shutdown();
|
||||
}
|
||||
|
||||
TEST_F(ServerTest, ProcessBidiStreamingRpcTest) {
|
||||
server_->Start();
|
||||
|
||||
auto reader_writer = stub_->GetRunningSum(&client_context_);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
proto::GetSumRequest request;
|
||||
request.set_input(i);
|
||||
EXPECT_TRUE(reader_writer->Write(request));
|
||||
}
|
||||
reader_writer->WritesDone();
|
||||
proto::GetSumResponse response;
|
||||
|
||||
std::list<int> expected_responses = {0, 0, 1, 1, 3, 3};
|
||||
while (reader_writer->Read(&response)) {
|
||||
EXPECT_EQ(expected_responses.front(), response.output());
|
||||
expected_responses.pop_front();
|
||||
}
|
||||
EXPECT_TRUE(expected_responses.empty());
|
||||
|
||||
server_->Shutdown();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace framework
|
||||
} // namespace cartographer_grpc
|
||||
|
|
|
@ -105,6 +105,8 @@ void Service::HandleRead(Rpc* rpc, bool ok) {
|
|||
|
||||
// Reads completed.
|
||||
rpc->OnReadsDone();
|
||||
|
||||
RemoveIfNotPending(rpc);
|
||||
}
|
||||
|
||||
void Service::HandleWrite(Rpc* rpc, bool ok) {
|
||||
|
@ -112,6 +114,9 @@ void Service::HandleWrite(Rpc* rpc, bool ok) {
|
|||
LOG(ERROR) << "Write failed";
|
||||
}
|
||||
|
||||
// Send the next message or potentially finish the connection.
|
||||
rpc->PerformWriteIfNeeded();
|
||||
|
||||
RemoveIfNotPending(rpc);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue