Implement unary gRPC calls. (#719)
[RFC=0002](https://github.com/googlecartographer/rfcs/blob/master/text/0002-cloud-based-mapping-1.md)master
parent
24f253a2aa
commit
5147af9763
|
@ -24,7 +24,16 @@ message GetSumResponse {
|
||||||
int32 output = 1;
|
int32 output = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message GetSquareRequest {
|
||||||
|
int32 input = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetSquareResponse {
|
||||||
|
int32 output = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Provides information about the gRPC server.
|
// Provides information about the gRPC server.
|
||||||
service Math {
|
service Math {
|
||||||
rpc GetSum(stream GetSumRequest) returns (GetSumResponse);
|
rpc GetSum(stream GetSumRequest) returns (GetSumResponse);
|
||||||
|
rpc GetSquare(GetSquareRequest) returns (GetSquareResponse);
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,6 +35,7 @@ Rpc::Rpc(int method_index,
|
||||||
new_connection_event_{Event::NEW_CONNECTION, this, false},
|
new_connection_event_{Event::NEW_CONNECTION, this, false},
|
||||||
read_event_{Event::READ, this, false},
|
read_event_{Event::READ, this, false},
|
||||||
write_event_{Event::WRITE, this, false},
|
write_event_{Event::WRITE, this, false},
|
||||||
|
finish_event_{Event::FINISH, this, false},
|
||||||
done_event_{Event::DONE, this, false},
|
done_event_{Event::DONE, this, false},
|
||||||
handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) {
|
handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) {
|
||||||
InitializeReadersAndWriters(rpc_handler_info_.rpc_type);
|
InitializeReadersAndWriters(rpc_handler_info_.rpc_type);
|
||||||
|
@ -69,6 +70,12 @@ void Rpc::RequestNextMethodInvocation() {
|
||||||
server_completion_queue_, server_completion_queue_,
|
server_completion_queue_, server_completion_queue_,
|
||||||
&new_connection_event_);
|
&new_connection_event_);
|
||||||
break;
|
break;
|
||||||
|
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||||
|
service_->RequestAsyncUnary(
|
||||||
|
method_index_, &server_context_, request_.get(),
|
||||||
|
streaming_interface(), server_completion_queue_,
|
||||||
|
server_completion_queue_, &new_connection_event_);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "RPC type not implemented.";
|
LOG(FATAL) << "RPC type not implemented.";
|
||||||
}
|
}
|
||||||
|
@ -81,6 +88,13 @@ void Rpc::RequestStreamingReadIfNeeded() {
|
||||||
read_event_.pending = true;
|
read_event_.pending = true;
|
||||||
async_reader_interface()->Read(request_.get(), &read_event_);
|
async_reader_interface()->Read(request_.get(), &read_event_);
|
||||||
break;
|
break;
|
||||||
|
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||||
|
// For NORMAL_RPC we don't have to do anything here, since gRPC
|
||||||
|
// automatically issues a READ request and places the request into the
|
||||||
|
// 'Message' we provided to 'RequestAsyncUnary' above.
|
||||||
|
OnRequest();
|
||||||
|
OnReadsDone();
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "RPC type not implemented.";
|
LOG(FATAL) << "RPC type not implemented.";
|
||||||
}
|
}
|
||||||
|
@ -88,15 +102,28 @@ void Rpc::RequestStreamingReadIfNeeded() {
|
||||||
|
|
||||||
void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
|
void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
|
||||||
response_ = std::move(message);
|
response_ = std::move(message);
|
||||||
write_event_.pending = true;
|
switch (rpc_handler_info_.rpc_type) {
|
||||||
server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK,
|
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||||
&write_event_);
|
server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK,
|
||||||
|
&finish_event_);
|
||||||
|
finish_event_.pending = true;
|
||||||
|
break;
|
||||||
|
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||||
|
server_async_response_writer_->Finish(*response_.get(),
|
||||||
|
::grpc::Status::OK, &finish_event_);
|
||||||
|
finish_event_.pending = true;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "RPC type not implemented.";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() {
|
::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() {
|
||||||
switch (rpc_handler_info_.rpc_type) {
|
switch (rpc_handler_info_.rpc_type) {
|
||||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||||
return server_async_reader_.get();
|
return server_async_reader_.get();
|
||||||
|
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||||
|
return server_async_response_writer_.get();
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "RPC type not implemented.";
|
LOG(FATAL) << "RPC type not implemented.";
|
||||||
}
|
}
|
||||||
|
@ -108,6 +135,8 @@ Rpc::async_reader_interface() {
|
||||||
switch (rpc_handler_info_.rpc_type) {
|
switch (rpc_handler_info_.rpc_type) {
|
||||||
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
|
||||||
return server_async_reader_.get();
|
return server_async_reader_.get();
|
||||||
|
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||||
|
LOG(FATAL) << "For NORMAL_RPC no streaming interface exists.";
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "RPC type not implemented.";
|
LOG(FATAL) << "RPC type not implemented.";
|
||||||
}
|
}
|
||||||
|
@ -122,6 +151,8 @@ Rpc::RpcEvent* Rpc::GetRpcEvent(Event event) {
|
||||||
return &read_event_;
|
return &read_event_;
|
||||||
case Event::WRITE:
|
case Event::WRITE:
|
||||||
return &write_event_;
|
return &write_event_;
|
||||||
|
case Event::FINISH:
|
||||||
|
return &finish_event_;
|
||||||
case Event::DONE:
|
case Event::DONE:
|
||||||
return &done_event_;
|
return &done_event_;
|
||||||
}
|
}
|
||||||
|
@ -139,6 +170,11 @@ void Rpc::InitializeReadersAndWriters(
|
||||||
google::protobuf::Message, google::protobuf::Message>>(
|
google::protobuf::Message, google::protobuf::Message>>(
|
||||||
&server_context_);
|
&server_context_);
|
||||||
break;
|
break;
|
||||||
|
case ::grpc::internal::RpcMethod::NORMAL_RPC:
|
||||||
|
server_async_response_writer_ = cartographer::common::make_unique<
|
||||||
|
::grpc::ServerAsyncResponseWriter<google::protobuf::Message>>(
|
||||||
|
&server_context_);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "RPC type not implemented.";
|
LOG(FATAL) << "RPC type not implemented.";
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ namespace framework {
|
||||||
class Service;
|
class Service;
|
||||||
class Rpc {
|
class Rpc {
|
||||||
public:
|
public:
|
||||||
enum class Event { NEW_CONNECTION = 0, READ, WRITE, DONE };
|
enum class Event { NEW_CONNECTION = 0, READ, WRITE, FINISH, DONE };
|
||||||
struct RpcEvent {
|
struct RpcEvent {
|
||||||
const Event event;
|
const Event event;
|
||||||
Rpc* rpc;
|
Rpc* rpc;
|
||||||
|
@ -79,6 +79,7 @@ class Rpc {
|
||||||
RpcEvent new_connection_event_;
|
RpcEvent new_connection_event_;
|
||||||
RpcEvent read_event_;
|
RpcEvent read_event_;
|
||||||
RpcEvent write_event_;
|
RpcEvent write_event_;
|
||||||
|
RpcEvent finish_event_;
|
||||||
RpcEvent done_event_;
|
RpcEvent done_event_;
|
||||||
|
|
||||||
std::unique_ptr<google::protobuf::Message> request_;
|
std::unique_ptr<google::protobuf::Message> request_;
|
||||||
|
@ -86,6 +87,8 @@ class Rpc {
|
||||||
|
|
||||||
std::unique_ptr<RpcHandlerInterface> handler_;
|
std::unique_ptr<RpcHandlerInterface> handler_;
|
||||||
|
|
||||||
|
std::unique_ptr<::grpc::ServerAsyncResponseWriter<google::protobuf::Message>>
|
||||||
|
server_async_response_writer_;
|
||||||
std::unique_ptr<::grpc::ServerAsyncReader<google::protobuf::Message,
|
std::unique_ptr<::grpc::ServerAsyncReader<google::protobuf::Message,
|
||||||
google::protobuf::Message>>
|
google::protobuf::Message>>
|
||||||
server_async_reader_;
|
server_async_reader_;
|
||||||
|
|
|
@ -32,7 +32,7 @@ class RpcHandlerInterface {
|
||||||
virtual void SetRpc(Rpc* rpc) = 0;
|
virtual void SetRpc(Rpc* rpc) = 0;
|
||||||
virtual void OnRequestInternal(
|
virtual void OnRequestInternal(
|
||||||
const ::google::protobuf::Message* request) = 0;
|
const ::google::protobuf::Message* request) = 0;
|
||||||
virtual void OnReadsDone() = 0;
|
virtual void OnReadsDone(){};
|
||||||
};
|
};
|
||||||
|
|
||||||
using RpcHandlerFactory = std::function<std::unique_ptr<RpcHandlerInterface>(
|
using RpcHandlerFactory = std::function<std::unique_ptr<RpcHandlerInterface>(
|
||||||
|
@ -43,6 +43,7 @@ struct RpcHandlerInfo {
|
||||||
const google::protobuf::Descriptor* response_descriptor;
|
const google::protobuf::Descriptor* response_descriptor;
|
||||||
const RpcHandlerFactory rpc_handler_factory;
|
const RpcHandlerFactory rpc_handler_factory;
|
||||||
const grpc::internal::RpcMethod::RpcType rpc_type;
|
const grpc::internal::RpcMethod::RpcType rpc_type;
|
||||||
|
const std::string fully_qualified_name;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace framework
|
} // namespace framework
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
|
||||||
|
@ -53,6 +54,9 @@ class Server {
|
||||||
|
|
||||||
template <typename RpcHandlerType, typename ServiceType>
|
template <typename RpcHandlerType, typename ServiceType>
|
||||||
void RegisterHandler(const std::string& method_name) {
|
void RegisterHandler(const std::string& method_name) {
|
||||||
|
std::stringstream fully_qualified_name;
|
||||||
|
fully_qualified_name << "/" << ServiceType::service_full_name() << "/"
|
||||||
|
<< method_name;
|
||||||
rpc_handlers_[ServiceType::service_full_name()].emplace(
|
rpc_handlers_[ServiceType::service_full_name()].emplace(
|
||||||
method_name,
|
method_name,
|
||||||
RpcHandlerInfo{
|
RpcHandlerInfo{
|
||||||
|
@ -66,7 +70,8 @@ class Server {
|
||||||
return rpc_handler;
|
return rpc_handler;
|
||||||
},
|
},
|
||||||
RpcType<typename RpcHandlerType::IncomingType,
|
RpcType<typename RpcHandlerType::IncomingType,
|
||||||
typename RpcHandlerType::OutgoingType>::value});
|
typename RpcHandlerType::OutgoingType>::value,
|
||||||
|
fully_qualified_name.str()});
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -33,7 +33,7 @@ class MathServerContext : public ExecutionContext {
|
||||||
int additional_increment() { return 10; }
|
int additional_increment() { return 10; }
|
||||||
};
|
};
|
||||||
|
|
||||||
class GetServerOptionsHandler
|
class GetSumHandler
|
||||||
: public RpcHandler<Stream<proto::GetSumRequest>, proto::GetSumResponse> {
|
: public RpcHandler<Stream<proto::GetSumRequest>, proto::GetSumResponse> {
|
||||||
public:
|
public:
|
||||||
void OnRequest(const proto::GetSumRequest& request) override {
|
void OnRequest(const proto::GetSumRequest& request) override {
|
||||||
|
@ -51,6 +51,16 @@ class GetServerOptionsHandler
|
||||||
int sum_ = 0;
|
int sum_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class GetSquareHandler
|
||||||
|
: public RpcHandler<proto::GetSquareRequest, proto::GetSquareResponse> {
|
||||||
|
void OnRequest(const proto::GetSquareRequest& request) override {
|
||||||
|
auto response =
|
||||||
|
cartographer::common::make_unique<proto::GetSquareResponse>();
|
||||||
|
response->set_output(request.input() * request.input());
|
||||||
|
Send(std::move(response));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// TODO(cschuet): Due to the hard-coded part these tests will become flaky when
|
// TODO(cschuet): Due to the hard-coded part these tests will become flaky when
|
||||||
// run in parallel. It would be nice to find a way to solve that. gRPC also
|
// run in parallel. It would be nice to find a way to solve that. gRPC also
|
||||||
// allows to communicate over UNIX domain sockets.
|
// allows to communicate over UNIX domain sockets.
|
||||||
|
@ -63,12 +73,19 @@ class ServerTest : public ::testing::Test {
|
||||||
Server::Builder server_builder;
|
Server::Builder server_builder;
|
||||||
server_builder.SetServerAddress(kServerAddress);
|
server_builder.SetServerAddress(kServerAddress);
|
||||||
server_builder.SetNumberOfThreads(kNumThreads);
|
server_builder.SetNumberOfThreads(kNumThreads);
|
||||||
server_builder.RegisterHandler<GetServerOptionsHandler, proto::Math>(
|
server_builder.RegisterHandler<GetSumHandler, proto::Math>("GetSum");
|
||||||
"GetSum");
|
server_builder.RegisterHandler<GetSquareHandler, proto::Math>("GetSquare");
|
||||||
server_ = server_builder.Build();
|
server_ = server_builder.Build();
|
||||||
|
|
||||||
|
client_channel_ =
|
||||||
|
grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials());
|
||||||
|
stub_ = proto::Math::NewStub(client_channel_);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Server> server_;
|
std::unique_ptr<Server> server_;
|
||||||
|
std::shared_ptr<grpc::Channel> client_channel_;
|
||||||
|
std::unique_ptr<proto::Math::Stub> stub_;
|
||||||
|
grpc::ClientContext client_context_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(ServerTest, StartAndStopServerTest) {
|
TEST_F(ServerTest, StartAndStopServerTest) {
|
||||||
|
@ -81,13 +98,9 @@ TEST_F(ServerTest, ProcessRpcStreamTest) {
|
||||||
cartographer::common::make_unique<MathServerContext>());
|
cartographer::common::make_unique<MathServerContext>());
|
||||||
server_->Start();
|
server_->Start();
|
||||||
|
|
||||||
auto channel =
|
|
||||||
grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials());
|
|
||||||
std::unique_ptr<proto::Math::Stub> stub(proto::Math::NewStub(channel));
|
|
||||||
grpc::ClientContext context;
|
|
||||||
proto::GetSumResponse result;
|
proto::GetSumResponse result;
|
||||||
std::unique_ptr<grpc::ClientWriter<proto::GetSumRequest> > writer(
|
std::unique_ptr<grpc::ClientWriter<proto::GetSumRequest> > writer(
|
||||||
stub->GetSum(&context, &result));
|
stub_->GetSum(&client_context_, &result));
|
||||||
for (int i = 0; i < 3; ++i) {
|
for (int i = 0; i < 3; ++i) {
|
||||||
proto::GetSumRequest request;
|
proto::GetSumRequest request;
|
||||||
request.set_input(i);
|
request.set_input(i);
|
||||||
|
@ -101,6 +114,19 @@ TEST_F(ServerTest, ProcessRpcStreamTest) {
|
||||||
server_->Shutdown();
|
server_->Shutdown();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ServerTest, ProcessUnaryRpcTest) {
|
||||||
|
server_->Start();
|
||||||
|
|
||||||
|
proto::GetSquareResponse result;
|
||||||
|
proto::GetSquareRequest request;
|
||||||
|
request.set_input(11);
|
||||||
|
grpc::Status status = stub_->GetSquare(&client_context_, request, &result);
|
||||||
|
EXPECT_TRUE(status.ok());
|
||||||
|
EXPECT_EQ(result.output(), 121);
|
||||||
|
|
||||||
|
server_->Shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace framework
|
} // namespace framework
|
||||||
} // namespace cartographer_grpc
|
} // namespace cartographer_grpc
|
||||||
|
|
|
@ -26,13 +26,11 @@ Service::Service(const std::string& service_name,
|
||||||
const std::map<std::string, RpcHandlerInfo>& rpc_handler_infos)
|
const std::map<std::string, RpcHandlerInfo>& rpc_handler_infos)
|
||||||
: rpc_handler_infos_(rpc_handler_infos) {
|
: rpc_handler_infos_(rpc_handler_infos) {
|
||||||
for (const auto& rpc_handler_info : rpc_handler_infos_) {
|
for (const auto& rpc_handler_info : rpc_handler_infos_) {
|
||||||
std::string fully_qualified_method_name =
|
|
||||||
"/" + service_name + "/" + rpc_handler_info.first;
|
|
||||||
// The 'handler' below is set to 'nullptr' indicating that we want to
|
// The 'handler' below is set to 'nullptr' indicating that we want to
|
||||||
// handle this method asynchronously.
|
// handle this method asynchronously.
|
||||||
this->AddMethod(new grpc::internal::RpcServiceMethod(
|
this->AddMethod(new grpc::internal::RpcServiceMethod(
|
||||||
fully_qualified_method_name.c_str(), rpc_handler_info.second.rpc_type,
|
rpc_handler_info.second.fully_qualified_name.c_str(),
|
||||||
nullptr /* handler */));
|
rpc_handler_info.second.rpc_type, nullptr /* handler */));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,6 +63,9 @@ void Service::HandleEvent(Rpc::Event event, Rpc* rpc, bool ok) {
|
||||||
case Rpc::Event::WRITE:
|
case Rpc::Event::WRITE:
|
||||||
HandleWrite(rpc, ok);
|
HandleWrite(rpc, ok);
|
||||||
break;
|
break;
|
||||||
|
case Rpc::Event::FINISH:
|
||||||
|
HandleFinish(rpc, ok);
|
||||||
|
break;
|
||||||
case Rpc::Event::DONE:
|
case Rpc::Event::DONE:
|
||||||
HandleDone(rpc, ok);
|
HandleDone(rpc, ok);
|
||||||
break;
|
break;
|
||||||
|
@ -114,12 +115,21 @@ void Service::HandleWrite(Rpc* rpc, bool ok) {
|
||||||
RemoveIfNotPending(rpc);
|
RemoveIfNotPending(rpc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Service::HandleFinish(Rpc* rpc, bool ok) {
|
||||||
|
if (!ok) {
|
||||||
|
LOG(ERROR) << "Finish failed";
|
||||||
|
}
|
||||||
|
|
||||||
|
RemoveIfNotPending(rpc);
|
||||||
|
}
|
||||||
|
|
||||||
void Service::HandleDone(Rpc* rpc, bool ok) { RemoveIfNotPending(rpc); }
|
void Service::HandleDone(Rpc* rpc, bool ok) { RemoveIfNotPending(rpc); }
|
||||||
|
|
||||||
void Service::RemoveIfNotPending(Rpc* rpc) {
|
void Service::RemoveIfNotPending(Rpc* rpc) {
|
||||||
if (!rpc->GetRpcEvent(Rpc::Event::DONE)->pending &&
|
if (!rpc->GetRpcEvent(Rpc::Event::DONE)->pending &&
|
||||||
!rpc->GetRpcEvent(Rpc::Event::READ)->pending &&
|
!rpc->GetRpcEvent(Rpc::Event::READ)->pending &&
|
||||||
!rpc->GetRpcEvent(Rpc::Event::WRITE)->pending) {
|
!rpc->GetRpcEvent(Rpc::Event::WRITE)->pending &&
|
||||||
|
!rpc->GetRpcEvent(Rpc::Event::FINISH)->pending) {
|
||||||
active_rpcs_.Remove(rpc);
|
active_rpcs_.Remove(rpc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,6 +45,7 @@ class Service : public ::grpc::Service {
|
||||||
void HandleNewConnection(Rpc* rpc, bool ok);
|
void HandleNewConnection(Rpc* rpc, bool ok);
|
||||||
void HandleRead(Rpc* rpc, bool ok);
|
void HandleRead(Rpc* rpc, bool ok);
|
||||||
void HandleWrite(Rpc* rpc, bool ok);
|
void HandleWrite(Rpc* rpc, bool ok);
|
||||||
|
void HandleFinish(Rpc* rpc, bool ok);
|
||||||
void HandleDone(Rpc* rpc, bool ok);
|
void HandleDone(Rpc* rpc, bool ok);
|
||||||
|
|
||||||
void RemoveIfNotPending(Rpc* rpc);
|
void RemoveIfNotPending(Rpc* rpc);
|
||||||
|
|
Loading…
Reference in New Issue