Christoph Schütte 2017-11-30 13:18:16 +01:00 committed by GitHub
parent 24f253a2aa
commit 5147af9763
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 110 additions and 19 deletions

View File

@ -24,7 +24,16 @@ message GetSumResponse {
int32 output = 1;
}
message GetSquareRequest {
int32 input = 1;
}
message GetSquareResponse {
int32 output = 1;
}
// Provides information about the gRPC server.
service Math {
rpc GetSum(stream GetSumRequest) returns (GetSumResponse);
rpc GetSquare(GetSquareRequest) returns (GetSquareResponse);
}

View File

@ -35,6 +35,7 @@ Rpc::Rpc(int method_index,
new_connection_event_{Event::NEW_CONNECTION, this, false},
read_event_{Event::READ, this, false},
write_event_{Event::WRITE, this, false},
finish_event_{Event::FINISH, this, false},
done_event_{Event::DONE, this, false},
handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) {
InitializeReadersAndWriters(rpc_handler_info_.rpc_type);
@ -69,6 +70,12 @@ void Rpc::RequestNextMethodInvocation() {
server_completion_queue_, server_completion_queue_,
&new_connection_event_);
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
service_->RequestAsyncUnary(
method_index_, &server_context_, request_.get(),
streaming_interface(), server_completion_queue_,
server_completion_queue_, &new_connection_event_);
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
@ -81,6 +88,13 @@ void Rpc::RequestStreamingReadIfNeeded() {
read_event_.pending = true;
async_reader_interface()->Read(request_.get(), &read_event_);
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
// For NORMAL_RPC we don't have to do anything here, since gRPC
// automatically issues a READ request and places the request into the
// 'Message' we provided to 'RequestAsyncUnary' above.
OnRequest();
OnReadsDone();
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
@ -88,15 +102,28 @@ void Rpc::RequestStreamingReadIfNeeded() {
void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
response_ = std::move(message);
write_event_.pending = true;
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK,
&write_event_);
&finish_event_);
finish_event_.pending = true;
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
server_async_response_writer_->Finish(*response_.get(),
::grpc::Status::OK, &finish_event_);
finish_event_.pending = true;
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
}
::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() {
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
return server_async_reader_.get();
case ::grpc::internal::RpcMethod::NORMAL_RPC:
return server_async_response_writer_.get();
default:
LOG(FATAL) << "RPC type not implemented.";
}
@ -108,6 +135,8 @@ Rpc::async_reader_interface() {
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
return server_async_reader_.get();
case ::grpc::internal::RpcMethod::NORMAL_RPC:
LOG(FATAL) << "For NORMAL_RPC no streaming interface exists.";
default:
LOG(FATAL) << "RPC type not implemented.";
}
@ -122,6 +151,8 @@ Rpc::RpcEvent* Rpc::GetRpcEvent(Event event) {
return &read_event_;
case Event::WRITE:
return &write_event_;
case Event::FINISH:
return &finish_event_;
case Event::DONE:
return &done_event_;
}
@ -139,6 +170,11 @@ void Rpc::InitializeReadersAndWriters(
google::protobuf::Message, google::protobuf::Message>>(
&server_context_);
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
server_async_response_writer_ = cartographer::common::make_unique<
::grpc::ServerAsyncResponseWriter<google::protobuf::Message>>(
&server_context_);
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}

View File

@ -36,7 +36,7 @@ namespace framework {
class Service;
class Rpc {
public:
enum class Event { NEW_CONNECTION = 0, READ, WRITE, DONE };
enum class Event { NEW_CONNECTION = 0, READ, WRITE, FINISH, DONE };
struct RpcEvent {
const Event event;
Rpc* rpc;
@ -79,6 +79,7 @@ class Rpc {
RpcEvent new_connection_event_;
RpcEvent read_event_;
RpcEvent write_event_;
RpcEvent finish_event_;
RpcEvent done_event_;
std::unique_ptr<google::protobuf::Message> request_;
@ -86,6 +87,8 @@ class Rpc {
std::unique_ptr<RpcHandlerInterface> handler_;
std::unique_ptr<::grpc::ServerAsyncResponseWriter<google::protobuf::Message>>
server_async_response_writer_;
std::unique_ptr<::grpc::ServerAsyncReader<google::protobuf::Message,
google::protobuf::Message>>
server_async_reader_;

View File

@ -32,7 +32,7 @@ class RpcHandlerInterface {
virtual void SetRpc(Rpc* rpc) = 0;
virtual void OnRequestInternal(
const ::google::protobuf::Message* request) = 0;
virtual void OnReadsDone() = 0;
virtual void OnReadsDone(){};
};
using RpcHandlerFactory = std::function<std::unique_ptr<RpcHandlerInterface>(
@ -43,6 +43,7 @@ struct RpcHandlerInfo {
const google::protobuf::Descriptor* response_descriptor;
const RpcHandlerFactory rpc_handler_factory;
const grpc::internal::RpcMethod::RpcType rpc_type;
const std::string fully_qualified_name;
};
} // namespace framework

View File

@ -19,6 +19,7 @@
#include <cstddef>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
@ -53,6 +54,9 @@ class Server {
template <typename RpcHandlerType, typename ServiceType>
void RegisterHandler(const std::string& method_name) {
std::stringstream fully_qualified_name;
fully_qualified_name << "/" << ServiceType::service_full_name() << "/"
<< method_name;
rpc_handlers_[ServiceType::service_full_name()].emplace(
method_name,
RpcHandlerInfo{
@ -66,7 +70,8 @@ class Server {
return rpc_handler;
},
RpcType<typename RpcHandlerType::IncomingType,
typename RpcHandlerType::OutgoingType>::value});
typename RpcHandlerType::OutgoingType>::value,
fully_qualified_name.str()});
}
private:

View File

@ -33,7 +33,7 @@ class MathServerContext : public ExecutionContext {
int additional_increment() { return 10; }
};
class GetServerOptionsHandler
class GetSumHandler
: public RpcHandler<Stream<proto::GetSumRequest>, proto::GetSumResponse> {
public:
void OnRequest(const proto::GetSumRequest& request) override {
@ -51,6 +51,16 @@ class GetServerOptionsHandler
int sum_ = 0;
};
class GetSquareHandler
: public RpcHandler<proto::GetSquareRequest, proto::GetSquareResponse> {
void OnRequest(const proto::GetSquareRequest& request) override {
auto response =
cartographer::common::make_unique<proto::GetSquareResponse>();
response->set_output(request.input() * request.input());
Send(std::move(response));
}
};
// TODO(cschuet): Due to the hard-coded part these tests will become flaky when
// run in parallel. It would be nice to find a way to solve that. gRPC also
// allows to communicate over UNIX domain sockets.
@ -63,12 +73,19 @@ class ServerTest : public ::testing::Test {
Server::Builder server_builder;
server_builder.SetServerAddress(kServerAddress);
server_builder.SetNumberOfThreads(kNumThreads);
server_builder.RegisterHandler<GetServerOptionsHandler, proto::Math>(
"GetSum");
server_builder.RegisterHandler<GetSumHandler, proto::Math>("GetSum");
server_builder.RegisterHandler<GetSquareHandler, proto::Math>("GetSquare");
server_ = server_builder.Build();
client_channel_ =
grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials());
stub_ = proto::Math::NewStub(client_channel_);
}
std::unique_ptr<Server> server_;
std::shared_ptr<grpc::Channel> client_channel_;
std::unique_ptr<proto::Math::Stub> stub_;
grpc::ClientContext client_context_;
};
TEST_F(ServerTest, StartAndStopServerTest) {
@ -81,13 +98,9 @@ TEST_F(ServerTest, ProcessRpcStreamTest) {
cartographer::common::make_unique<MathServerContext>());
server_->Start();
auto channel =
grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials());
std::unique_ptr<proto::Math::Stub> stub(proto::Math::NewStub(channel));
grpc::ClientContext context;
proto::GetSumResponse result;
std::unique_ptr<grpc::ClientWriter<proto::GetSumRequest> > writer(
stub->GetSum(&context, &result));
stub_->GetSum(&client_context_, &result));
for (int i = 0; i < 3; ++i) {
proto::GetSumRequest request;
request.set_input(i);
@ -101,6 +114,19 @@ TEST_F(ServerTest, ProcessRpcStreamTest) {
server_->Shutdown();
}
TEST_F(ServerTest, ProcessUnaryRpcTest) {
server_->Start();
proto::GetSquareResponse result;
proto::GetSquareRequest request;
request.set_input(11);
grpc::Status status = stub_->GetSquare(&client_context_, request, &result);
EXPECT_TRUE(status.ok());
EXPECT_EQ(result.output(), 121);
server_->Shutdown();
}
} // namespace
} // namespace framework
} // namespace cartographer_grpc

View File

@ -26,13 +26,11 @@ Service::Service(const std::string& service_name,
const std::map<std::string, RpcHandlerInfo>& rpc_handler_infos)
: rpc_handler_infos_(rpc_handler_infos) {
for (const auto& rpc_handler_info : rpc_handler_infos_) {
std::string fully_qualified_method_name =
"/" + service_name + "/" + rpc_handler_info.first;
// The 'handler' below is set to 'nullptr' indicating that we want to
// handle this method asynchronously.
this->AddMethod(new grpc::internal::RpcServiceMethod(
fully_qualified_method_name.c_str(), rpc_handler_info.second.rpc_type,
nullptr /* handler */));
rpc_handler_info.second.fully_qualified_name.c_str(),
rpc_handler_info.second.rpc_type, nullptr /* handler */));
}
}
@ -65,6 +63,9 @@ void Service::HandleEvent(Rpc::Event event, Rpc* rpc, bool ok) {
case Rpc::Event::WRITE:
HandleWrite(rpc, ok);
break;
case Rpc::Event::FINISH:
HandleFinish(rpc, ok);
break;
case Rpc::Event::DONE:
HandleDone(rpc, ok);
break;
@ -114,12 +115,21 @@ void Service::HandleWrite(Rpc* rpc, bool ok) {
RemoveIfNotPending(rpc);
}
void Service::HandleFinish(Rpc* rpc, bool ok) {
if (!ok) {
LOG(ERROR) << "Finish failed";
}
RemoveIfNotPending(rpc);
}
void Service::HandleDone(Rpc* rpc, bool ok) { RemoveIfNotPending(rpc); }
void Service::RemoveIfNotPending(Rpc* rpc) {
if (!rpc->GetRpcEvent(Rpc::Event::DONE)->pending &&
!rpc->GetRpcEvent(Rpc::Event::READ)->pending &&
!rpc->GetRpcEvent(Rpc::Event::WRITE)->pending) {
!rpc->GetRpcEvent(Rpc::Event::WRITE)->pending &&
!rpc->GetRpcEvent(Rpc::Event::FINISH)->pending) {
active_rpcs_.Remove(rpc);
}
}

View File

@ -45,6 +45,7 @@ class Service : public ::grpc::Service {
void HandleNewConnection(Rpc* rpc, bool ok);
void HandleRead(Rpc* rpc, bool ok);
void HandleWrite(Rpc* rpc, bool ok);
void HandleFinish(Rpc* rpc, bool ok);
void HandleDone(Rpc* rpc, bool ok);
void RemoveIfNotPending(Rpc* rpc);