Implement RpcHandler::GetWriter and add unittest (#767)

master
Christoph Schütte 2017-12-18 21:27:03 +01:00 committed by Wally B. Feed
parent def442b9db
commit c79425cbb0
5 changed files with 69 additions and 0 deletions

View File

@ -32,9 +32,18 @@ message GetSquareResponse {
int32 output = 1;
}
message GetEchoRequest {
int32 input = 1;
}
message GetEchoResponse {
int32 output = 1;
}
// Provides information about the gRPC server.
service Math {
rpc GetSum(stream GetSumRequest) returns (GetSumResponse);
rpc GetSquare(GetSquareRequest) returns (GetSquareResponse);
rpc GetRunningSum(stream GetSumRequest) returns (stream GetSumResponse);
rpc GetEcho(GetEchoRequest) returns (GetEchoResponse);
}

View File

@ -283,6 +283,8 @@ bool Rpc::IsAnyEventPending() {
IsRpcEventPending(Rpc::Event::FINISH);
}
std::weak_ptr<Rpc> Rpc::GetWeakPtr() { return weak_ptr_factory_(this); }
ActiveRpcs::ActiveRpcs() : lock_() {}
void Rpc::InitializeReadersAndWriters(

View File

@ -73,6 +73,7 @@ class Rpc {
bool IsAnyEventPending();
void SetEventQueue(EventQueue* event_queue) { event_queue_ = event_queue; }
EventQueue* event_queue() { return event_queue_; }
std::weak_ptr<Rpc> GetWeakPtr();
private:
struct SendItem {

View File

@ -35,6 +35,7 @@ class RpcHandler : public RpcHandlerInterface {
using OutgoingType = Outgoing;
using RequestType = StripStream<Incoming>;
using ResponseType = StripStream<Outgoing>;
using Writer = std::function<bool(std::unique_ptr<ResponseType>)>;
void SetExecutionContext(ExecutionContext* execution_context) {
execution_context_ = execution_context;
@ -57,6 +58,16 @@ class RpcHandler : public RpcHandlerInterface {
T* GetUnsynchronizedContext() {
return dynamic_cast<T*>(execution_context_);
}
Writer GetWriter() {
std::weak_ptr<Rpc> weak_ptr_rpc = rpc_->GetWeakPtr();
return [weak_ptr_rpc](std::unique_ptr<ResponseType> message) {
if (auto rpc = weak_ptr_rpc.lock()) {
rpc->Write(std::move(message));
return true;
}
return false;
};
}
private:
Rpc* rpc_;

View File

@ -16,6 +16,8 @@
#include "cartographer_grpc/framework/server.h"
#include <future>
#include "cartographer_grpc/framework/execution_context.h"
#include "cartographer_grpc/framework/proto/math_service.grpc.pb.h"
#include "cartographer_grpc/framework/proto/math_service.pb.h"
@ -28,9 +30,11 @@ namespace cartographer_grpc {
namespace framework {
namespace {
using EchoResponder = std::function<bool()>;
class MathServerContext : public ExecutionContext {
public:
int additional_increment() { return 10; }
std::promise<EchoResponder> echo_responder;
};
class GetSumHandler
@ -82,6 +86,21 @@ class GetSquareHandler
}
};
class GetEchoHandler
: public RpcHandler<proto::GetEchoRequest, proto::GetEchoResponse> {
void OnRequest(const proto::GetEchoRequest& request) override {
int value = request.input();
Writer writer = GetWriter();
GetContext<MathServerContext>()->echo_responder.set_value(
[writer, value]() {
auto response =
cartographer::common::make_unique<proto::GetEchoResponse>();
response->set_output(value);
return writer(std::move(response));
});
}
};
// TODO(cschuet): Due to the hard-coded part these tests will become flaky when
// run in parallel. It would be nice to find a way to solve that. gRPC also
// allows to communicate over UNIX domain sockets.
@ -99,6 +118,7 @@ class ServerTest : public ::testing::Test {
server_builder.RegisterHandler<GetSquareHandler, proto::Math>("GetSquare");
server_builder.RegisterHandler<GetRunningSumHandler, proto::Math>(
"GetRunningSum");
server_builder.RegisterHandler<GetEchoHandler, proto::Math>("GetEcho");
server_ = server_builder.Build();
client_channel_ =
@ -173,6 +193,32 @@ TEST_F(ServerTest, ProcessBidiStreamingRpcTest) {
server_->Shutdown();
}
TEST_F(ServerTest, WriteFromOtherThread) {
server_->SetExecutionContext(
cartographer::common::make_unique<MathServerContext>());
server_->Start();
proto::GetEchoResponse result;
proto::GetEchoRequest request;
request.set_input(13);
Server* server = server_.get();
std::thread response_thread([server]() {
std::future<EchoResponder> responder_future =
server->GetContext<MathServerContext>()->echo_responder.get_future();
responder_future.wait();
auto responder = responder_future.get();
CHECK(responder());
});
grpc::Status status = stub_->GetEcho(&client_context_, request, &result);
response_thread.join();
EXPECT_TRUE(status.ok());
EXPECT_EQ(result.output(), 13);
server_->Shutdown();
}
} // namespace
} // namespace framework
} // namespace cartographer_grpc