Christoph Schütte 2017-11-29 10:40:26 +01:00 committed by Wally B. Feed
parent 02359a98ae
commit 3a46804393
9 changed files with 290 additions and 110 deletions

View File

@ -16,15 +16,15 @@ syntax = "proto3";
package cartographer_grpc.framework.proto; package cartographer_grpc.framework.proto;
message Request { message GetSumRequest {
int32 input = 1; int32 input = 1;
} }
message Response { message GetSumResponse {
int32 output = 1; int32 output = 1;
} }
// Provides information about the gRPC server. // Provides information about the gRPC server.
service Math { service Math {
rpc GetSum(stream Request) returns (Response); rpc GetSum(stream GetSumRequest) returns (GetSumResponse);
} }

View File

@ -15,6 +15,7 @@
*/ */
#include "cartographer_grpc/framework/rpc.h" #include "cartographer_grpc/framework/rpc.h"
#include "cartographer_grpc/framework/service.h"
#include "cartographer/common/make_unique.h" #include "cartographer/common/make_unique.h"
#include "glog/logging.h" #include "glog/logging.h"
@ -28,19 +29,65 @@ Rpc::Rpc(int method_index,
: method_index_(method_index), : method_index_(method_index),
server_completion_queue_(server_completion_queue), server_completion_queue_(server_completion_queue),
rpc_handler_info_(rpc_handler_info), rpc_handler_info_(rpc_handler_info),
new_connection_state_{State::NEW_CONNECTION, service, this}, service_(service),
read_state_{State::READ, service, this}, new_connection_event_{Event::NEW_CONNECTION, this, false},
write_state_{State::WRITE, service, this}, read_event_{Event::READ, this, false},
done_state_{State::DONE, service, this} { write_event_{Event::WRITE, this, false},
InitializeResponders(rpc_handler_info_.rpc_type); done_event_{Event::DONE, this, false},
handler_(rpc_handler_info_.rpc_handler_factory(this)) {
InitializeReadersAndWriters(rpc_handler_info_.rpc_type);
// Initialize the prototypical request and response messages.
request_.reset(::google::protobuf::MessageFactory::generated_factory()
->GetPrototype(rpc_handler_info_.request_descriptor)
->New());
response_.reset(::google::protobuf::MessageFactory::generated_factory()
->GetPrototype(rpc_handler_info_.response_descriptor)
->New());
} }
::grpc::ServerCompletionQueue* Rpc::server_completion_queue() { std::unique_ptr<Rpc> Rpc::Clone() {
return server_completion_queue_; return cartographer::common::make_unique<Rpc>(
method_index_, server_completion_queue_, rpc_handler_info_, service_);
} }
::grpc::internal::RpcMethod::RpcType Rpc::rpc_type() const { void Rpc::OnRequest() { handler_->OnRequestInternal(request_.get()); }
return rpc_handler_info_.rpc_type;
void Rpc::OnReadsDone() { handler_->OnReadsDone(); }
void Rpc::RequestNextMethodInvocation() {
done_event_.pending = true;
new_connection_event_.pending = true;
server_context_.AsyncNotifyWhenDone(&done_event_);
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
service_->RequestAsyncClientStreaming(
method_index_, &server_context_, streaming_interface(),
server_completion_queue_, server_completion_queue_,
&new_connection_event_);
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
}
void Rpc::RequestStreamingReadIfNeeded() {
// For request-streaming RPCs ask the client to start sending requests.
switch (rpc_handler_info_.rpc_type) {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
read_event_.pending = true;
async_reader_interface()->Read(request_.get(), &read_event_);
break;
default:
LOG(FATAL) << "RPC type not implemented.";
}
}
void Rpc::Write(std::unique_ptr<::google::protobuf::Message> message) {
response_ = std::move(message);
write_event_.pending = true;
server_async_reader_->Finish(*response_.get(), ::grpc::Status::OK,
&write_event_);
} }
::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() { ::grpc::internal::ServerAsyncStreamingInterface* Rpc::streaming_interface() {
@ -53,23 +100,35 @@ Rpc::Rpc(int method_index,
LOG(FATAL) << "Never reached."; LOG(FATAL) << "Never reached.";
} }
Rpc::RpcState* Rpc::GetRpcState(State state) { ::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>*
switch (state) { Rpc::async_reader_interface() {
case State::NEW_CONNECTION: switch (rpc_handler_info_.rpc_type) {
return &new_connection_state_; case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
case State::READ: return server_async_reader_.get();
return &read_state_; default:
case State::WRITE: LOG(FATAL) << "RPC type not implemented.";
return &write_state_; }
case State::DONE: LOG(FATAL) << "Never reached.";
return &done_state_; }
Rpc::RpcEvent* Rpc::GetRpcEvent(Event event) {
switch (event) {
case Event::NEW_CONNECTION:
return &new_connection_event_;
case Event::READ:
return &read_event_;
case Event::WRITE:
return &write_event_;
case Event::DONE:
return &done_event_;
} }
LOG(FATAL) << "Never reached."; LOG(FATAL) << "Never reached.";
} }
ActiveRpcs::ActiveRpcs() : lock_() {} ActiveRpcs::ActiveRpcs() : lock_() {}
void Rpc::InitializeResponders(::grpc::internal::RpcMethod::RpcType rpc_type) { void Rpc::InitializeReadersAndWriters(
::grpc::internal::RpcMethod::RpcType rpc_type) {
switch (rpc_type) { switch (rpc_type) {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: case ::grpc::internal::RpcMethod::CLIENT_STREAMING:
server_async_reader_ = server_async_reader_ =

View File

@ -21,10 +21,11 @@
#include <unordered_set> #include <unordered_set>
#include "cartographer/common/mutex.h" #include "cartographer/common/mutex.h"
#include "cartographer_grpc/framework/rpc_handler.h" #include "cartographer_grpc/framework/rpc_handler_interface.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
#include "grpc++/impl/codegen/async_stream.h" #include "grpc++/impl/codegen/async_stream.h"
#include "grpc++/impl/codegen/async_unary_call.h"
#include "grpc++/impl/codegen/proto_utils.h" #include "grpc++/impl/codegen/proto_utils.h"
#include "grpc++/impl/codegen/service_type.h" #include "grpc++/impl/codegen/service_type.h"
@ -34,45 +35,54 @@ namespace framework {
class Service; class Service;
class Rpc { class Rpc {
public: public:
enum class State { NEW_CONNECTION = 0, READ, WRITE, DONE }; enum class Event { NEW_CONNECTION = 0, READ, WRITE, DONE };
struct RpcState { struct RpcEvent {
const State state; const Event event;
Service* service;
Rpc* rpc; Rpc* rpc;
// Indicates whether the event is pending completion. E.g. 'event = READ'
// and 'pending = true' means that a read has been requested but hasn't
// completed yet. While 'pending = false' indicates, that the read has
// completed and currently no read is in-flight.
bool pending;
}; };
Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue,
const RpcHandlerInfo& rpc_handler_info, Service* service); const RpcHandlerInfo& rpc_handler_info, Service* service);
std::unique_ptr<Rpc> Clone();
int method_index() const { return method_index_; } void OnRequest();
::grpc::ServerCompletionQueue* server_completion_queue(); void OnReadsDone();
::grpc::internal::RpcMethod::RpcType rpc_type() const; void RequestNextMethodInvocation();
::grpc::ServerContext* server_context() { return &server_context_; } void RequestStreamingReadIfNeeded();
::grpc::internal::ServerAsyncStreamingInterface* streaming_interface(); void Write(std::unique_ptr<::google::protobuf::Message> message);
RpcState* GetRpcState(State state); Service* service() { return service_; }
const RpcHandlerInfo& rpc_handler_info() const { return rpc_handler_info_; } RpcEvent* GetRpcEvent(Event event);
::google::protobuf::Message* request() { return request_.get(); }
::google::protobuf::Message* response() { return response_.get(); }
private: private:
Rpc(const Rpc&) = delete; Rpc(const Rpc&) = delete;
Rpc& operator=(const Rpc&) = delete; Rpc& operator=(const Rpc&) = delete;
void InitializeResponders(::grpc::internal::RpcMethod::RpcType rpc_type); void InitializeReadersAndWriters(
::grpc::internal::RpcMethod::RpcType rpc_type);
::grpc::internal::AsyncReaderInterface<::google::protobuf::Message>*
async_reader_interface();
::grpc::internal::ServerAsyncStreamingInterface* streaming_interface();
int method_index_; int method_index_;
::grpc::ServerCompletionQueue* server_completion_queue_; ::grpc::ServerCompletionQueue* server_completion_queue_;
RpcHandlerInfo rpc_handler_info_; RpcHandlerInfo rpc_handler_info_;
Service* service_;
::grpc::ServerContext server_context_; ::grpc::ServerContext server_context_;
RpcState new_connection_state_; RpcEvent new_connection_event_;
RpcState read_state_; RpcEvent read_event_;
RpcState write_state_; RpcEvent write_event_;
RpcState done_state_; RpcEvent done_event_;
std::unique_ptr<google::protobuf::Message> request_; std::unique_ptr<google::protobuf::Message> request_;
std::unique_ptr<google::protobuf::Message> response_; std::unique_ptr<google::protobuf::Message> response_;
std::unique_ptr<RpcHandlerInterface> handler_;
std::unique_ptr<::grpc::ServerAsyncReader<google::protobuf::Message, std::unique_ptr<::grpc::ServerAsyncReader<google::protobuf::Message,
google::protobuf::Message>> google::protobuf::Message>>
server_async_reader_; server_async_reader_;

View File

@ -17,30 +17,16 @@
#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H
#define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H #define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H
#include "cartographer_grpc/framework/rpc.h"
#include "cartographer_grpc/framework/rpc_handler_interface.h"
#include "cartographer_grpc/framework/type_traits.h" #include "cartographer_grpc/framework/type_traits.h"
#include "glog/logging.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
namespace cartographer_grpc { namespace cartographer_grpc {
namespace framework { namespace framework {
class Rpc;
class RpcHandlerInterface {
public:
virtual ~RpcHandlerInterface() = default;
virtual void SetRpc(Rpc* rpc) = 0;
};
using RpcHandlerFactory =
std::function<std::unique_ptr<RpcHandlerInterface>(Rpc*)>;
struct RpcHandlerInfo {
const google::protobuf::Descriptor* request_descriptor;
const google::protobuf::Descriptor* response_descriptor;
const RpcHandlerFactory rpc_handler_factory;
const grpc::internal::RpcMethod::RpcType rpc_type;
};
template <typename Incoming, typename Outgoing> template <typename Incoming, typename Outgoing>
class RpcHandler : public RpcHandlerInterface { class RpcHandler : public RpcHandlerInterface {
public: public:
@ -50,6 +36,14 @@ class RpcHandler : public RpcHandlerInterface {
using ResponseType = StripStream<Outgoing>; using ResponseType = StripStream<Outgoing>;
void SetRpc(Rpc* rpc) override { rpc_ = rpc; } void SetRpc(Rpc* rpc) override { rpc_ = rpc; }
void OnRequestInternal(const ::google::protobuf::Message* request) override {
DCHECK(dynamic_cast<const RequestType*>(request));
OnRequest(static_cast<const RequestType&>(*request));
}
virtual void OnRequest(const RequestType& request) = 0;
void Send(std::unique_ptr<ResponseType> response) {
rpc_->Write(std::move(response));
}
private: private:
Rpc* rpc_; Rpc* rpc_;

View File

@ -0,0 +1,49 @@
/*
* Copyright 2017 The Cartographer Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H
#define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H
#include "google/protobuf/message.h"
#include "grpc++/grpc++.h"
namespace cartographer_grpc {
namespace framework {
class Rpc;
class RpcHandlerInterface {
public:
virtual ~RpcHandlerInterface() = default;
virtual void SetRpc(Rpc* rpc) = 0;
virtual void OnRequestInternal(
const ::google::protobuf::Message* request) = 0;
virtual void OnReadsDone() = 0;
};
using RpcHandlerFactory =
std::function<std::unique_ptr<RpcHandlerInterface>(Rpc*)>;
struct RpcHandlerInfo {
const google::protobuf::Descriptor* request_descriptor;
const google::protobuf::Descriptor* response_descriptor;
const RpcHandlerFactory rpc_handler_factory;
const grpc::internal::RpcMethod::RpcType rpc_type;
};
} // namespace framework
} // namespace cartographer_grpc
#endif // CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H

View File

@ -65,8 +65,9 @@ void Server::RunCompletionQueue(
bool ok; bool ok;
void* tag; void* tag;
while (completion_queue->Next(&tag, &ok)) { while (completion_queue->Next(&tag, &ok)) {
auto* rpc_state = static_cast<Rpc::RpcState*>(tag); auto* rpc_event = static_cast<Rpc::RpcEvent*>(tag);
rpc_state->service->HandleEvent(rpc_state->state, rpc_state->rpc, ok); rpc_event->rpc->service()->HandleEvent(rpc_event->event, rpc_event->rpc,
ok);
} }
} }

View File

@ -28,17 +28,68 @@ namespace framework {
namespace { namespace {
class GetServerOptionsHandler class GetServerOptionsHandler
: public RpcHandler<Stream<proto::Request>, proto::Response> {}; : public RpcHandler<Stream<proto::GetSumRequest>, proto::GetSumResponse> {
public:
void OnRequest(const proto::GetSumRequest& request) override {
sum_ += request.input();
}
TEST(ServerTest, StartServerTest) { void OnReadsDone() override {
Server::Builder server_builder; auto response = cartographer::common::make_unique<proto::GetSumResponse>();
server_builder.SetServerAddress("0.0.0.0:50051"); response->set_output(sum_);
server_builder.SetNumberOfThreads(1); Send(std::move(response));
server_builder.RegisterHandler<GetServerOptionsHandler, proto::Math>( }
"GetSum");
std::unique_ptr<Server> server = server_builder.Build(); private:
server->Start(); int sum_ = 0;
server->Shutdown(); };
// TODO(cschuet): Due to the hard-coded part these tests will become flaky when
// run in parallel. It would be nice to find a way to solve that. gRPC also
// allows to communicate over UNIX domain sockets.
const std::string kServerAddress = "localhost:50051";
const std::size_t kNumThreads = 1;
class ServerTest : public ::testing::Test {
protected:
void SetUp() override {
Server::Builder server_builder;
server_builder.SetServerAddress(kServerAddress);
server_builder.SetNumberOfThreads(kNumThreads);
server_builder.RegisterHandler<GetServerOptionsHandler, proto::Math>(
"GetSum");
server_ = server_builder.Build();
}
std::unique_ptr<Server> server_;
};
TEST_F(ServerTest, StartAndStopServerTest) {
server_->Start();
server_->Shutdown();
}
TEST_F(ServerTest, ProcessRpcStreamTest) {
server_->Start();
auto channel =
grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials());
std::unique_ptr<proto::Math::Stub> stub(proto::Math::NewStub(channel));
grpc::ClientContext context;
proto::GetSumResponse result;
std::unique_ptr<grpc::ClientWriter<proto::GetSumRequest> > writer(
stub->GetSum(&context, &result));
for (int i = 0; i < 3; ++i) {
proto::GetSumRequest request;
request.set_input(i);
EXPECT_TRUE(writer->Write(request));
}
writer->WritesDone();
grpc::Status status = writer->Finish();
EXPECT_TRUE(status.ok());
EXPECT_EQ(result.output(), 3);
server_->Shutdown();
} }
} // namespace } // namespace

View File

@ -44,8 +44,7 @@ void Service::StartServing(
Rpc* rpc = active_rpcs_.Add(cartographer::common::make_unique<Rpc>( Rpc* rpc = active_rpcs_.Add(cartographer::common::make_unique<Rpc>(
i, completion_queue_thread.completion_queue(), i, completion_queue_thread.completion_queue(),
rpc_handler_info.second, this)); rpc_handler_info.second, this));
RequestNextMethodInvocation(i, rpc, rpc->RequestNextMethodInvocation();
completion_queue_thread.completion_queue());
} }
++i; ++i;
} }
@ -53,16 +52,19 @@ void Service::StartServing(
void Service::StopServing() { shutting_down_ = true; } void Service::StopServing() { shutting_down_ = true; }
void Service::HandleEvent(Rpc::State state, Rpc* rpc, bool ok) { void Service::HandleEvent(Rpc::Event event, Rpc* rpc, bool ok) {
switch (state) { rpc->GetRpcEvent(event)->pending = false;
case Rpc::State::NEW_CONNECTION: switch (event) {
case Rpc::Event::NEW_CONNECTION:
HandleNewConnection(rpc, ok); HandleNewConnection(rpc, ok);
break; break;
case Rpc::State::READ: case Rpc::Event::READ:
HandleRead(rpc, ok);
break; break;
case Rpc::State::WRITE: case Rpc::Event::WRITE:
HandleWrite(rpc, ok);
break; break;
case Rpc::State::DONE: case Rpc::Event::DONE:
HandleDone(rpc, ok); HandleDone(rpc, ok);
break; break;
} }
@ -70,7 +72,9 @@ void Service::HandleEvent(Rpc::State state, Rpc* rpc, bool ok) {
void Service::HandleNewConnection(Rpc* rpc, bool ok) { void Service::HandleNewConnection(Rpc* rpc, bool ok) {
if (shutting_down_) { if (shutting_down_) {
LOG(WARNING) << "Server shutting down. Refusing to handle new RPCs."; if (ok) {
LOG(WARNING) << "Server shutting down. Refusing to handle new RPCs.";
}
active_rpcs_.Remove(rpc); active_rpcs_.Remove(rpc);
return; return;
} }
@ -80,33 +84,42 @@ void Service::HandleNewConnection(Rpc* rpc, bool ok) {
active_rpcs_.Remove(rpc); active_rpcs_.Remove(rpc);
} }
// TODO(cschuet): Request next read for the new connection. if (ok) {
// For request-streaming RPCs ask the client to start sending requests.
rpc->RequestStreamingReadIfNeeded();
}
// Create new active rpc to handle next connection. // Create new active rpc to handle next connection and register it for the
Rpc* next_rpc = active_rpcs_.Add(cartographer::common::make_unique<Rpc>( // incoming connection.
rpc->method_index(), rpc->server_completion_queue(), active_rpcs_.Add(rpc->Clone())->RequestNextMethodInvocation();
rpc->rpc_handler_info(), this));
RequestNextMethodInvocation(rpc->method_index(), next_rpc,
rpc->server_completion_queue());
} }
void Service::HandleDone(Rpc* rpc, bool ok) { LOG(FATAL) << "Not implemented"; } void Service::HandleRead(Rpc* rpc, bool ok) {
if (ok) {
rpc->OnRequest();
rpc->RequestStreamingReadIfNeeded();
return;
}
void Service::RequestNextMethodInvocation( // Reads completed.
int method_index, Rpc* rpc, rpc->OnReadsDone();
::grpc::ServerCompletionQueue* completion_queue) { }
rpc->server_context()->AsyncNotifyWhenDone(
rpc->GetRpcState(Rpc::State::DONE)); void Service::HandleWrite(Rpc* rpc, bool ok) {
switch (rpc->rpc_type()) { if (!ok) {
case ::grpc::internal::RpcMethod::CLIENT_STREAMING: LOG(ERROR) << "Write failed";
RequestAsyncClientStreaming(method_index, rpc->server_context(), }
rpc->streaming_interface(), completion_queue,
completion_queue, RemoveIfNotPending(rpc);
rpc->GetRpcState(Rpc::State::NEW_CONNECTION)); }
break;
default: void Service::HandleDone(Rpc* rpc, bool ok) { RemoveIfNotPending(rpc); }
LOG(FATAL) << "RPC type not implemented.";
void Service::RemoveIfNotPending(Rpc* rpc) {
if (!rpc->GetRpcEvent(Rpc::Event::DONE)->pending &&
!rpc->GetRpcEvent(Rpc::Event::READ)->pending &&
!rpc->GetRpcEvent(Rpc::Event::WRITE)->pending) {
active_rpcs_.Remove(rpc);
} }
} }

View File

@ -17,6 +17,7 @@
#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_SERVICE_H #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_SERVICE_H
#define CARTOGRAPHER_GRPC_FRAMEWORK_SERVICE_H #define CARTOGRAPHER_GRPC_FRAMEWORK_SERVICE_H
#include "cartographer_grpc/framework/completion_queue_thread.h"
#include "cartographer_grpc/framework/rpc.h" #include "cartographer_grpc/framework/rpc.h"
#include "cartographer_grpc/framework/rpc_handler.h" #include "cartographer_grpc/framework/rpc_handler.h"
#include "grpc++/impl/codegen/service_type.h" #include "grpc++/impl/codegen/service_type.h"
@ -30,20 +31,22 @@ namespace framework {
// 'Rpc' handler objects. // 'Rpc' handler objects.
class Service : public ::grpc::Service { class Service : public ::grpc::Service {
public: public:
friend class Rpc;
Service(const std::string& service_name, Service(const std::string& service_name,
const std::map<std::string, RpcHandlerInfo>& rpc_handlers); const std::map<std::string, RpcHandlerInfo>& rpc_handlers);
void StartServing(std::vector<CompletionQueueThread>& completion_queues); void StartServing(std::vector<CompletionQueueThread>& completion_queues);
void HandleEvent(Rpc::State state, Rpc* rpc, bool ok); void HandleEvent(Rpc::Event event, Rpc* rpc, bool ok);
void StopServing(); void StopServing();
private: private:
void RequestNextMethodInvocation(
int method_index, Rpc* rpc,
::grpc::ServerCompletionQueue* completion_queue);
void HandleNewConnection(Rpc* rpc, bool ok); void HandleNewConnection(Rpc* rpc, bool ok);
void HandleRead(Rpc* rpc, bool ok);
void HandleWrite(Rpc* rpc, bool ok);
void HandleDone(Rpc* rpc, bool ok); void HandleDone(Rpc* rpc, bool ok);
void RemoveIfNotPending(Rpc* rpc);
std::map<std::string, RpcHandlerInfo> rpc_handler_infos_; std::map<std::string, RpcHandlerInfo> rpc_handler_infos_;
ActiveRpcs active_rpcs_; ActiveRpcs active_rpcs_;
bool shutting_down_ = false; bool shutting_down_ = false;