Christoph Schütte 2017-11-30 13:18:16 +01:00 committed by GitHub
parent 24f253a2aa
commit 5147af9763
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 110 additions and 19 deletions

View File

@ -24,7 +24,16 @@ message GetSumResponse {
int32 output = 1; int32 output = 1;
} }
message GetSquareRequest {
int32 input = 1;
}
message GetSquareResponse {
int32 output = 1;
}
// Provides information about the gRPC server. // Provides information about the gRPC server.
service Math { service Math {
rpc GetSum(stream GetSumRequest) returns (GetSumResponse); rpc GetSum(stream GetSumRequest) returns (GetSumResponse);
rpc GetSquare(GetSquareRequest) returns (GetSquareResponse);
} }

View File

@ -35,6 +35,7 @@ Rpc::Rpc(int method_index,
new_connection_event_{Event::NEW_CONNECTION, this, false}, new_connection_event_{Event::NEW_CONNECTION, this, false},
read_event_{Event::READ, this, false}, read_event_{Event::READ, this, false},
write_event_{Event::WRITE, this, false}, write_event_{Event::WRITE, this, false},
finish_event_{Event::FINISH, this, false},
done_event_{Event::DONE, this, false}, done_event_{Event::DONE, this, false},
handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) { handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) {
InitializeReadersAndWriters(rpc_handler_info_.rpc_type); InitializeReadersAndWriters(rpc_handler_info_.rpc_type);
@ -69,6 +70,12 @@ void Rpc::RequestNextMethodInvocation() {
server_completion_queue_, server_completion_queue_, server_completion_queue_, server_completion_queue_,
&new_connection_event_); &new_connection_event_);
break; break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
service_->RequestAsyncUnary(
method_index_, &server_context_, request_.get(),
streaming_interface(), server_completion_queue_,
server_completion_queue_, &new_connection_event_);
break;
default: default:
LOG(FATAL) << "RPC type not implemented."; LOG(FATAL) << "RPC type not implemented.";
} }
@ -81,6 +88,13 @@ void Rpc::RequestStreamingReadIfNeeded() {
read_event_.pending = true; read_event_.pending = true;
async_reader_interface()->Read(request_.get(), &read_event_); async_reader_interface()->Read(request_.get(), &read_event_);
break; break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
// For NORMAL_RPC we don't have to do anything here, since gRPC
// automatically issues a READ request and places the request into the
// 'Message' we provided to 'RequestAsyncUnary' above.
OnRequest();
OnReadsDone();
break;
default: default:
LOG(FATAL) << "RPC type not implemented."; LOG(FATAL) << "RPC type not implemented.";
} }
@ -88,15 +102,28 @@ void Rpc::RequestStreamingReadIfNeeded() {
void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) { void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
response_ = std::move(message); response_ = std::move(message);
write_event_.pending = true; switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK, server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK,
&write_event_); &finish_event_);
finish_event_.pending = true;
break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
server_async_response_writer_->Finish(*response_.get(),
::grpc::Status::OK, &finish_event_);
finish_event_.pending = true;
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
} }
::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() { ::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() {
switch (rpc_handler_info_.rpc_type) { switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
return server_async_reader_.get(); return server_async_reader_.get();
case ::grpc::internal::RpcMethod::NORMAL_RPC:
return server_async_response_writer_.get();
default: default:
LOG(FATAL) << "RPC type not implemented."; LOG(FATAL) << "RPC type not implemented.";
} }
@ -108,6 +135,8 @@ Rpc::async_reader_interface() {
switch (rpc_handler_info_.rpc_type) { switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
return server_async_reader_.get(); return server_async_reader_.get();
case ::grpc::internal::RpcMethod::NORMAL_RPC:
LOG(FATAL) << "For NORMAL_RPC no streaming interface exists.";
default: default:
LOG(FATAL) << "RPC type not implemented."; LOG(FATAL) << "RPC type not implemented.";
} }
@ -122,6 +151,8 @@ Rpc::RpcEvent* Rpc::GetRpcEvent(Event event) {
return &read_event_; return &read_event_;
case Event::WRITE: case Event::WRITE:
return &write_event_; return &write_event_;
case Event::FINISH:
return &finish_event_;
case Event::DONE: case Event::DONE:
return &done_event_; return &done_event_;
} }
@ -139,6 +170,11 @@ void Rpc::InitializeReadersAndWriters(
google::protobuf::Message, google::protobuf::Message>>( google::protobuf::Message, google::protobuf::Message>>(
&server_context_); &server_context_);
break; break;
case ::grpc::internal::RpcMethod::NORMAL_RPC:
server_async_response_writer_ = cartographer::common::make_unique<
::grpc::ServerAsyncResponseWriter<google::protobuf::Message>>(
&server_context_);
break;
default: default:
LOG(FATAL) << "RPC type not implemented."; LOG(FATAL) << "RPC type not implemented.";
} }

View File

@ -36,7 +36,7 @@ namespace framework {
class Service; class Service;
class Rpc { class Rpc {
public: public:
enum class Event { NEW_CONNECTION = 0, READ, WRITE, DONE }; enum class Event { NEW_CONNECTION = 0, READ, WRITE, FINISH, DONE };
struct RpcEvent { struct RpcEvent {
const Event event; const Event event;
Rpc* rpc; Rpc* rpc;
@ -79,6 +79,7 @@ class Rpc {
RpcEvent new_connection_event_; RpcEvent new_connection_event_;
RpcEvent read_event_; RpcEvent read_event_;
RpcEvent write_event_; RpcEvent write_event_;
RpcEvent finish_event_;
RpcEvent done_event_; RpcEvent done_event_;
std::unique_ptr<google::protobuf::Message> request_; std::unique_ptr<google::protobuf::Message> request_;
@ -86,6 +87,8 @@ class Rpc {
std::unique_ptr<RpcHandlerInterface> handler_; std::unique_ptr<RpcHandlerInterface> handler_;
std::unique_ptr<::grpc::ServerAsyncResponseWriter<google::protobuf::Message>>
server_async_response_writer_;
std::unique_ptr<::grpc::ServerAsyncReader<google::protobuf::Message, std::unique_ptr<::grpc::ServerAsyncReader<google::protobuf::Message,
google::protobuf::Message>> google::protobuf::Message>>
server_async_reader_; server_async_reader_;

View File

@ -32,7 +32,7 @@ class RpcHandlerInterface {
virtual void SetRpc(Rpc* rpc) = 0; virtual void SetRpc(Rpc* rpc) = 0;
virtual void OnRequestInternal( virtual void OnRequestInternal(
const ::google::protobuf::Message* request) = 0; const ::google::protobuf::Message* request) = 0;
virtual void OnReadsDone() = 0; virtual void OnReadsDone(){};
}; };
using RpcHandlerFactory = std::function<std::unique_ptr<RpcHandlerInterface>( using RpcHandlerFactory = std::function<std::unique_ptr<RpcHandlerInterface>(
@ -43,6 +43,7 @@ struct RpcHandlerInfo {
const google::protobuf::Descriptor* response_descriptor; const google::protobuf::Descriptor* response_descriptor;
const RpcHandlerFactory rpc_handler_factory; const RpcHandlerFactory rpc_handler_factory;
const grpc::internal::RpcMethod::RpcType rpc_type; const grpc::internal::RpcMethod::RpcType rpc_type;
const std::string fully_qualified_name;
}; };
} // namespace framework } // namespace framework

View File

@ -19,6 +19,7 @@
#include <cstddef> #include <cstddef>
#include <memory> #include <memory>
#include <sstream>
#include <string> #include <string>
#include <thread> #include <thread>
@ -53,6 +54,9 @@ class Server {
template <typename RpcHandlerType, typename ServiceType> template <typename RpcHandlerType, typename ServiceType>
void RegisterHandler(const std::string& method_name) { void RegisterHandler(const std::string& method_name) {
std::stringstream fully_qualified_name;
fully_qualified_name << "/" << ServiceType::service_full_name() << "/"
<< method_name;
rpc_handlers_[ServiceType::service_full_name()].emplace( rpc_handlers_[ServiceType::service_full_name()].emplace(
method_name, method_name,
RpcHandlerInfo{ RpcHandlerInfo{
@ -66,7 +70,8 @@ class Server {
return rpc_handler; return rpc_handler;
}, },
RpcType<typename RpcHandlerType::IncomingType, RpcType<typename RpcHandlerType::IncomingType,
typename RpcHandlerType::OutgoingType>::value}); typename RpcHandlerType::OutgoingType>::value,
fully_qualified_name.str()});
} }
private: private:

View File

@ -33,7 +33,7 @@ class MathServerContext : public ExecutionContext {
int additional_increment() { return 10; } int additional_increment() { return 10; }
}; };
class GetServerOptionsHandler class GetSumHandler
: public RpcHandler<Stream<proto::GetSumRequest>, proto::GetSumResponse> { : public RpcHandler<Stream<proto::GetSumRequest>, proto::GetSumResponse> {
public: public:
void OnRequest(const proto::GetSumRequest& request) override { void OnRequest(const proto::GetSumRequest& request) override {
@ -51,6 +51,16 @@ class GetServerOptionsHandler
int sum_ = 0; int sum_ = 0;
}; };
class GetSquareHandler
: public RpcHandler<proto::GetSquareRequest, proto::GetSquareResponse> {
void OnRequest(const proto::GetSquareRequest& request) override {
auto response =
cartographer::common::make_unique<proto::GetSquareResponse>();
response->set_output(request.input() * request.input());
Send(std::move(response));
}
};
// TODO(cschuet): Due to the hard-coded part these tests will become flaky when // TODO(cschuet): Due to the hard-coded part these tests will become flaky when
// run in parallel. It would be nice to find a way to solve that. gRPC also // run in parallel. It would be nice to find a way to solve that. gRPC also
// allows to communicate over UNIX domain sockets. // allows to communicate over UNIX domain sockets.
@ -63,12 +73,19 @@ class ServerTest : public ::testing::Test {
Server::Builder server_builder; Server::Builder server_builder;
server_builder.SetServerAddress(kServerAddress); server_builder.SetServerAddress(kServerAddress);
server_builder.SetNumberOfThreads(kNumThreads); server_builder.SetNumberOfThreads(kNumThreads);
server_builder.RegisterHandler<GetServerOptionsHandler, proto::Math>( server_builder.RegisterHandler<GetSumHandler, proto::Math>("GetSum");
"GetSum"); server_builder.RegisterHandler<GetSquareHandler, proto::Math>("GetSquare");
server_ = server_builder.Build(); server_ = server_builder.Build();
client_channel_ =
grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials());
stub_ = proto::Math::NewStub(client_channel_);
} }
std::unique_ptr<Server> server_; std::unique_ptr<Server> server_;
std::shared_ptr<grpc::Channel> client_channel_;
std::unique_ptr<proto::Math::Stub> stub_;
grpc::ClientContext client_context_;
}; };
TEST_F(ServerTest, StartAndStopServerTest) { TEST_F(ServerTest, StartAndStopServerTest) {
@ -81,13 +98,9 @@ TEST_F(ServerTest, ProcessRpcStreamTest) {
cartographer::common::make_unique<MathServerContext>()); cartographer::common::make_unique<MathServerContext>());
server_->Start(); server_->Start();
auto channel =
grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials());
std::unique_ptr<proto::Math::Stub> stub(proto::Math::NewStub(channel));
grpc::ClientContext context;
proto::GetSumResponse result; proto::GetSumResponse result;
std::unique_ptr<grpc::ClientWriter<proto::GetSumRequest> > writer( std::unique_ptr<grpc::ClientWriter<proto::GetSumRequest> > writer(
stub->GetSum(&context, &result)); stub_->GetSum(&client_context_, &result));
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
proto::GetSumRequest request; proto::GetSumRequest request;
request.set_input(i); request.set_input(i);
@ -101,6 +114,19 @@ TEST_F(ServerTest, ProcessRpcStreamTest) {
server_->Shutdown(); server_->Shutdown();
} }
TEST_F(ServerTest, ProcessUnaryRpcTest) {
server_->Start();
proto::GetSquareResponse result;
proto::GetSquareRequest request;
request.set_input(11);
grpc::Status status = stub_->GetSquare(&client_context_, request, &result);
EXPECT_TRUE(status.ok());
EXPECT_EQ(result.output(), 121);
server_->Shutdown();
}
} // namespace } // namespace
} // namespace framework } // namespace framework
} // namespace cartographer_grpc } // namespace cartographer_grpc

View File

@ -26,13 +26,11 @@ Service::Service(const std::string& service_name,
const std::map<std::string, RpcHandlerInfo>& rpc_handler_infos) const std::map<std::string, RpcHandlerInfo>& rpc_handler_infos)
: rpc_handler_infos_(rpc_handler_infos) { : rpc_handler_infos_(rpc_handler_infos) {
for (const auto& rpc_handler_info : rpc_handler_infos_) { for (const auto& rpc_handler_info : rpc_handler_infos_) {
std::string fully_qualified_method_name =
"/" + service_name + "/" + rpc_handler_info.first;
// The 'handler' below is set to 'nullptr' indicating that we want to // The 'handler' below is set to 'nullptr' indicating that we want to
// handle this method asynchronously. // handle this method asynchronously.
this->AddMethod(new grpc::internal::RpcServiceMethod( this->AddMethod(new grpc::internal::RpcServiceMethod(
fully_qualified_method_name.c_str(), rpc_handler_info.second.rpc_type, rpc_handler_info.second.fully_qualified_name.c_str(),
nullptr /* handler */)); rpc_handler_info.second.rpc_type, nullptr /* handler */));
} }
} }
@ -65,6 +63,9 @@ void Service::HandleEvent(Rpc::Event event, Rpc* rpc, bool ok) {
case Rpc::Event::WRITE: case Rpc::Event::WRITE:
HandleWrite(rpc, ok); HandleWrite(rpc, ok);
break; break;
case Rpc::Event::FINISH:
HandleFinish(rpc, ok);
break;
case Rpc::Event::DONE: case Rpc::Event::DONE:
HandleDone(rpc, ok); HandleDone(rpc, ok);
break; break;
@ -114,12 +115,21 @@ void Service::HandleWrite(Rpc* rpc, bool ok) {
RemoveIfNotPending(rpc); RemoveIfNotPending(rpc);
} }
void Service::HandleFinish(Rpc* rpc, bool ok) {
if (!ok) {
LOG(ERROR) << "Finish failed";
}
RemoveIfNotPending(rpc);
}
void Service::HandleDone(Rpc* rpc, bool ok) { RemoveIfNotPending(rpc); } void Service::HandleDone(Rpc* rpc, bool ok) { RemoveIfNotPending(rpc); }
void Service::RemoveIfNotPending(Rpc* rpc) { void Service::RemoveIfNotPending(Rpc* rpc) {
if (!rpc->GetRpcEvent(Rpc::Event::DONE)->pending && if (!rpc->GetRpcEvent(Rpc::Event::DONE)->pending &&
!rpc->GetRpcEvent(Rpc::Event::READ)->pending && !rpc->GetRpcEvent(Rpc::Event::READ)->pending &&
!rpc->GetRpcEvent(Rpc::Event::WRITE)->pending) { !rpc->GetRpcEvent(Rpc::Event::WRITE)->pending &&
!rpc->GetRpcEvent(Rpc::Event::FINISH)->pending) {
active_rpcs_.Remove(rpc); active_rpcs_.Remove(rpc);
} }
} }

View File

@ -45,6 +45,7 @@ class Service : public ::grpc::Service {
void HandleNewConnection(Rpc* rpc, bool ok); void HandleNewConnection(Rpc* rpc, bool ok);
void HandleRead(Rpc* rpc, bool ok); void HandleRead(Rpc* rpc, bool ok);
void HandleWrite(Rpc* rpc, bool ok); void HandleWrite(Rpc* rpc, bool ok);
void HandleFinish(Rpc* rpc, bool ok);
void HandleDone(Rpc* rpc, bool ok); void HandleDone(Rpc* rpc, bool ok);
void RemoveIfNotPending(Rpc* rpc); void RemoveIfNotPending(Rpc* rpc);