Introduce framework::Client (#867)

Introduces a framework::Client class that makes it more convenient to call gRPC methods.

[RFC=0002](https://github.com/googlecartographer/rfcs/blob/master/text/0002-cloud-based-mapping-1.md)
master
Christoph Schütte 2018-01-31 17:45:57 +01:00 committed by Wally B. Feed
parent 0440761474
commit a749d28a67
4 changed files with 202 additions and 38 deletions

View File

@ -0,0 +1,171 @@
/*
* Copyright 2018 The Cartographer Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_CLIENT_H
#define CARTOGRAPHER_GRPC_FRAMEWORK_CLIENT_H
#include "grpc++/grpc++.h"
#include "grpc++/impl/codegen/client_unary_call.h"
#include "grpc++/impl/codegen/sync_stream.h"
namespace cartographer_grpc {
namespace framework {
template <typename RpcHandlerType>
class Client {
public:
Client(std::shared_ptr<grpc::Channel> channel)
: channel_(channel),
rpc_method_name_(
RpcHandlerInterface::Instantiate<RpcHandlerType>()->method_name()),
rpc_method_(rpc_method_name_.c_str(),
RpcType<typename RpcHandlerType::IncomingType,
typename RpcHandlerType::OutgoingType>::value,
channel_) {}
bool Read(typename RpcHandlerType::ResponseType* response) {
switch (rpc_method_.method_type()) {
case grpc::internal::RpcMethod::BIDI_STREAMING:
InstantiateClientReaderWriterIfNeeded();
return client_reader_writer_->Read(response);
case grpc::internal::RpcMethod::SERVER_STREAMING:
CHECK(client_reader_);
return client_reader_->Read(response);
default:
LOG(FATAL) << "Not implemented.";
}
}
bool Write(const typename RpcHandlerType::RequestType& request) {
switch (rpc_method_.method_type()) {
case grpc::internal::RpcMethod::NORMAL_RPC:
return MakeBlockingUnaryCall(request, &response_).ok();
case grpc::internal::RpcMethod::CLIENT_STREAMING:
InstantiateClientWriterIfNeeded();
return client_writer_->Write(request);
case grpc::internal::RpcMethod::BIDI_STREAMING:
InstantiateClientReaderWriterIfNeeded();
return client_reader_writer_->Write(request);
case grpc::internal::RpcMethod::SERVER_STREAMING:
InstantiateClientReader(request);
return true;
}
LOG(FATAL) << "Not reached.";
}
bool WritesDone() {
switch (rpc_method_.method_type()) {
case grpc::internal::RpcMethod::CLIENT_STREAMING:
InstantiateClientWriterIfNeeded();
return client_writer_->WritesDone();
case grpc::internal::RpcMethod::BIDI_STREAMING:
InstantiateClientReaderWriterIfNeeded();
return client_reader_writer_->WritesDone();
default:
LOG(FATAL) << "Not implemented.";
}
}
grpc::Status Finish() {
switch (rpc_method_.method_type()) {
case grpc::internal::RpcMethod::CLIENT_STREAMING:
InstantiateClientWriterIfNeeded();
return client_writer_->Finish();
case grpc::internal::RpcMethod::BIDI_STREAMING:
InstantiateClientReaderWriterIfNeeded();
return client_reader_writer_->Finish();
case grpc::internal::RpcMethod::SERVER_STREAMING:
CHECK(client_reader_);
return client_reader_->Finish();
default:
LOG(FATAL) << "Not implemented.";
}
}
const typename RpcHandlerType::ResponseType& response() {
CHECK(rpc_method_.method_type() == grpc::internal::RpcMethod::NORMAL_RPC ||
rpc_method_.method_type() ==
grpc::internal::RpcMethod::CLIENT_STREAMING);
return response_;
}
private:
void InstantiateClientWriterIfNeeded() {
CHECK_EQ(rpc_method_.method_type(),
grpc::internal::RpcMethod::CLIENT_STREAMING);
if (!client_writer_) {
client_writer_.reset(
grpc::internal::ClientWriterFactory<
typename RpcHandlerType::RequestType>::Create(channel_.get(),
rpc_method_,
&client_context_,
&response_));
}
}
void InstantiateClientReaderWriterIfNeeded() {
CHECK_EQ(rpc_method_.method_type(),
grpc::internal::RpcMethod::BIDI_STREAMING);
if (!client_reader_writer_) {
client_reader_writer_.reset(
grpc::internal::ClientReaderWriterFactory<
typename RpcHandlerType::RequestType,
typename RpcHandlerType::ResponseType>::Create(channel_.get(),
rpc_method_,
&client_context_));
}
}
void InstantiateClientReader(
const typename RpcHandlerType::RequestType& request) {
CHECK_EQ(rpc_method_.method_type(),
grpc::internal::RpcMethod::SERVER_STREAMING);
client_reader_.reset(
grpc::internal::ClientReaderFactory<
typename RpcHandlerType::ResponseType>::Create(channel_.get(),
rpc_method_,
&client_context_,
request));
}
grpc::Status MakeBlockingUnaryCall(
const typename RpcHandlerType::RequestType& request,
typename RpcHandlerType::ResponseType* response) {
CHECK_EQ(rpc_method_.method_type(), grpc::internal::RpcMethod::NORMAL_RPC);
return ::grpc::internal::BlockingUnaryCall(
channel_.get(), rpc_method_, &client_context_, request, response);
}
std::shared_ptr<grpc::Channel> channel_;
grpc::ClientContext client_context_;
const std::string rpc_method_name_;
const ::grpc::internal::RpcMethod rpc_method_;
std::unique_ptr<grpc::ClientWriter<typename RpcHandlerType::RequestType>>
client_writer_;
std::unique_ptr<
grpc::ClientReaderWriter<typename RpcHandlerType::RequestType,
typename RpcHandlerType::ResponseType>>
client_reader_writer_;
std::unique_ptr<grpc::ClientReader<typename RpcHandlerType::ResponseType>>
client_reader_;
typename RpcHandlerType::ResponseType response_;
};
} // namespace framework
} // namespace cartographer_grpc
#endif // CARTOGRAPHER_GRPC_FRAMEWORK_CLIENT_H

View File

@ -17,6 +17,7 @@
#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H
#define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H #define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H
#include "cartographer/common/make_unique.h"
#include "cartographer_grpc/framework/execution_context.h" #include "cartographer_grpc/framework/execution_context.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
@ -40,6 +41,10 @@ class RpcHandlerInterface {
const ::google::protobuf::Message* request) = 0; const ::google::protobuf::Message* request) = 0;
virtual void OnReadsDone(){}; virtual void OnReadsDone(){};
virtual void OnFinish(){}; virtual void OnFinish(){};
template <class RpcHandlerType>
static std::unique_ptr<RpcHandlerType> Instantiate() {
return cartographer::common::make_unique<RpcHandlerType>();
}
}; };
using RpcHandlerFactory = std::function<std::unique_ptr<RpcHandlerInterface>( using RpcHandlerFactory = std::function<std::unique_ptr<RpcHandlerInterface>(

View File

@ -57,7 +57,8 @@ class Server {
template <typename RpcHandlerType> template <typename RpcHandlerType>
void RegisterHandler() { void RegisterHandler() {
std::string method_full_name = GetMethodFullName<RpcHandlerType>(); std::string method_full_name =
RpcHandlerInterface::Instantiate<RpcHandlerType>()->method_name();
std::string service_full_name; std::string service_full_name;
std::string method_name; std::string method_name;
std::tie(service_full_name, method_name) = std::tie(service_full_name, method_name) =
@ -83,11 +84,6 @@ class Server {
private: private:
using ServiceInfo = std::map<std::string, RpcHandlerInfo>; using ServiceInfo = std::map<std::string, RpcHandlerInfo>;
template <typename RpcHandlerType>
std::string GetMethodFullName() {
auto handler = cartographer::common::make_unique<const RpcHandlerType>();
return handler->method_name();
}
std::tuple<std::string /* service_full_name */, std::tuple<std::string /* service_full_name */,
std::string /* method_name */> std::string /* method_name */>
ParseMethodFullName(const std::string& method_full_name); ParseMethodFullName(const std::string& method_full_name);

View File

@ -18,6 +18,7 @@
#include <future> #include <future>
#include "cartographer_grpc/framework/client.h"
#include "cartographer_grpc/framework/execution_context.h" #include "cartographer_grpc/framework/execution_context.h"
#include "cartographer_grpc/framework/proto/math_service.grpc.pb.h" #include "cartographer_grpc/framework/proto/math_service.grpc.pb.h"
#include "cartographer_grpc/framework/proto/math_service.pb.h" #include "cartographer_grpc/framework/proto/math_service.pb.h"
@ -156,13 +157,10 @@ class ServerTest : public ::testing::Test {
client_channel_ = client_channel_ =
grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials()); grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials());
stub_ = proto::Math::NewStub(client_channel_);
} }
std::unique_ptr<Server> server_; std::unique_ptr<Server> server_;
std::shared_ptr<grpc::Channel> client_channel_; std::shared_ptr<grpc::Channel> client_channel_;
std::unique_ptr<proto::Math::Stub> stub_;
grpc::ClientContext client_context_;
}; };
TEST_F(ServerTest, StartAndStopServerTest) { TEST_F(ServerTest, StartAndStopServerTest) {
@ -175,18 +173,15 @@ TEST_F(ServerTest, ProcessRpcStreamTest) {
cartographer::common::make_unique<MathServerContext>()); cartographer::common::make_unique<MathServerContext>());
server_->Start(); server_->Start();
proto::GetSumResponse result; Client<GetSumHandler> client(client_channel_);
std::unique_ptr<grpc::ClientWriter<proto::GetSumRequest>> writer(
stub_->GetSum(&client_context_, &result));
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
proto::GetSumRequest request; proto::GetSumRequest request;
request.set_input(i); request.set_input(i);
EXPECT_TRUE(writer->Write(request)); EXPECT_TRUE(client.Write(request));
} }
writer->WritesDone(); EXPECT_TRUE(client.WritesDone());
grpc::Status status = writer->Finish(); EXPECT_TRUE(client.Finish().ok());
EXPECT_TRUE(status.ok()); EXPECT_EQ(client.response().output(), 33);
EXPECT_EQ(result.output(), 33);
server_->Shutdown(); server_->Shutdown();
} }
@ -194,12 +189,11 @@ TEST_F(ServerTest, ProcessRpcStreamTest) {
TEST_F(ServerTest, ProcessUnaryRpcTest) { TEST_F(ServerTest, ProcessUnaryRpcTest) {
server_->Start(); server_->Start();
proto::GetSquareResponse result; Client<GetSquareHandler> client(client_channel_);
proto::GetSquareRequest request; proto::GetSquareRequest request;
request.set_input(11); request.set_input(11);
grpc::Status status = stub_->GetSquare(&client_context_, request, &result); EXPECT_TRUE(client.Write(request));
EXPECT_TRUE(status.ok()); EXPECT_EQ(client.response().output(), 121);
EXPECT_EQ(result.output(), 121);
server_->Shutdown(); server_->Shutdown();
} }
@ -207,22 +201,21 @@ TEST_F(ServerTest, ProcessUnaryRpcTest) {
TEST_F(ServerTest, ProcessBidiStreamingRpcTest) { TEST_F(ServerTest, ProcessBidiStreamingRpcTest) {
server_->Start(); server_->Start();
auto reader_writer = stub_->GetRunningSum(&client_context_); Client<GetRunningSumHandler> client(client_channel_);
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
proto::GetSumRequest request; proto::GetSumRequest request;
request.set_input(i); request.set_input(i);
EXPECT_TRUE(reader_writer->Write(request)); EXPECT_TRUE(client.Write(request));
} }
reader_writer->WritesDone(); client.WritesDone();
proto::GetSumResponse response; proto::GetSumResponse response;
std::list<int> expected_responses = {0, 0, 1, 1, 3, 3}; std::list<int> expected_responses = {0, 0, 1, 1, 3, 3};
while (reader_writer->Read(&response)) { while (client.Read(&response)) {
EXPECT_EQ(expected_responses.front(), response.output()); EXPECT_EQ(expected_responses.front(), response.output());
expected_responses.pop_front(); expected_responses.pop_front();
} }
EXPECT_TRUE(expected_responses.empty()); EXPECT_TRUE(expected_responses.empty());
EXPECT_TRUE(reader_writer->Finish().ok()); EXPECT_TRUE(client.Finish().ok());
server_->Shutdown(); server_->Shutdown();
} }
@ -232,10 +225,6 @@ TEST_F(ServerTest, WriteFromOtherThread) {
cartographer::common::make_unique<MathServerContext>()); cartographer::common::make_unique<MathServerContext>());
server_->Start(); server_->Start();
proto::GetEchoResponse result;
proto::GetEchoRequest request;
request.set_input(13);
Server* server = server_.get(); Server* server = server_.get();
std::thread response_thread([server]() { std::thread response_thread([server]() {
std::future<EchoResponder> responder_future = std::future<EchoResponder> responder_future =
@ -245,10 +234,12 @@ TEST_F(ServerTest, WriteFromOtherThread) {
CHECK(responder()); CHECK(responder());
}); });
grpc::Status status = stub_->GetEcho(&client_context_, request, &result); Client<GetEchoHandler> client(client_channel_);
proto::GetEchoRequest request;
request.set_input(13);
EXPECT_TRUE(client.Write(request));
response_thread.join(); response_thread.join();
EXPECT_TRUE(status.ok()); EXPECT_EQ(client.response().output(), 13);
EXPECT_EQ(result.output(), 13);
server_->Shutdown(); server_->Shutdown();
} }
@ -256,17 +247,18 @@ TEST_F(ServerTest, WriteFromOtherThread) {
TEST_F(ServerTest, ProcessServerStreamingRpcTest) { TEST_F(ServerTest, ProcessServerStreamingRpcTest) {
server_->Start(); server_->Start();
Client<GetSequenceHandler> client(client_channel_);
proto::GetSequenceRequest request; proto::GetSequenceRequest request;
request.set_input(12); request.set_input(12);
auto reader = stub_->GetSequence(&client_context_, request);
client.Write(request);
proto::GetSequenceResponse response; proto::GetSequenceResponse response;
for (int i = 0; i < 12; ++i) { for (int i = 0; i < 12; ++i) {
EXPECT_TRUE(reader->Read(&response)); EXPECT_TRUE(client.Read(&response));
EXPECT_EQ(response.output(), i); EXPECT_EQ(response.output(), i);
} }
EXPECT_FALSE(reader->Read(&response)); EXPECT_FALSE(client.Read(&response));
EXPECT_TRUE(reader->Finish().ok()); EXPECT_TRUE(client.Finish().ok());
server_->Shutdown(); server_->Shutdown();
} }