diff --git a/cartographer_grpc/framework/proto/math_service.proto b/cartographer_grpc/framework/proto/math_service.proto index 474e1ea..64bb3a8 100644 --- a/cartographer_grpc/framework/proto/math_service.proto +++ b/cartographer_grpc/framework/proto/math_service.proto @@ -32,9 +32,18 @@ message GetSquareResponse { int32 output = 1; } +message GetEchoRequest { + int32 input = 1; +} + +message GetEchoResponse { + int32 output = 1; +} + // Provides information about the gRPC server. service Math { rpc GetSum(stream GetSumRequest) returns (GetSumResponse); rpc GetSquare(GetSquareRequest) returns (GetSquareResponse); rpc GetRunningSum(stream GetSumRequest) returns (stream GetSumResponse); + rpc GetEcho(GetEchoRequest) returns (GetEchoResponse); } diff --git a/cartographer_grpc/framework/rpc.cc b/cartographer_grpc/framework/rpc.cc index a945fd7..e539966 100644 --- a/cartographer_grpc/framework/rpc.cc +++ b/cartographer_grpc/framework/rpc.cc @@ -283,6 +283,8 @@ bool Rpc::IsAnyEventPending() { IsRpcEventPending(Rpc::Event::FINISH); } +std::weak_ptr Rpc::GetWeakPtr() { return weak_ptr_factory_(this); } + ActiveRpcs::ActiveRpcs() : lock_() {} void Rpc::InitializeReadersAndWriters( diff --git a/cartographer_grpc/framework/rpc.h b/cartographer_grpc/framework/rpc.h index 5b83909..95b66f3 100644 --- a/cartographer_grpc/framework/rpc.h +++ b/cartographer_grpc/framework/rpc.h @@ -73,6 +73,7 @@ class Rpc { bool IsAnyEventPending(); void SetEventQueue(EventQueue* event_queue) { event_queue_ = event_queue; } EventQueue* event_queue() { return event_queue_; } + std::weak_ptr GetWeakPtr(); private: struct SendItem { diff --git a/cartographer_grpc/framework/rpc_handler.h b/cartographer_grpc/framework/rpc_handler.h index ead2e7b..f625efb 100644 --- a/cartographer_grpc/framework/rpc_handler.h +++ b/cartographer_grpc/framework/rpc_handler.h @@ -35,6 +35,7 @@ class RpcHandler : public RpcHandlerInterface { using OutgoingType = Outgoing; using RequestType = StripStream; using ResponseType = StripStream; + using Writer = std::function)>; void SetExecutionContext(ExecutionContext* execution_context) { execution_context_ = execution_context; @@ -57,6 +58,16 @@ class RpcHandler : public RpcHandlerInterface { T* GetUnsynchronizedContext() { return dynamic_cast(execution_context_); } + Writer GetWriter() { + std::weak_ptr weak_ptr_rpc = rpc_->GetWeakPtr(); + return [weak_ptr_rpc](std::unique_ptr message) { + if (auto rpc = weak_ptr_rpc.lock()) { + rpc->Write(std::move(message)); + return true; + } + return false; + }; + } private: Rpc* rpc_; diff --git a/cartographer_grpc/framework/server_test.cc b/cartographer_grpc/framework/server_test.cc index c6873d6..4ec1f71 100644 --- a/cartographer_grpc/framework/server_test.cc +++ b/cartographer_grpc/framework/server_test.cc @@ -16,6 +16,8 @@ #include "cartographer_grpc/framework/server.h" +#include + #include "cartographer_grpc/framework/execution_context.h" #include "cartographer_grpc/framework/proto/math_service.grpc.pb.h" #include "cartographer_grpc/framework/proto/math_service.pb.h" @@ -28,9 +30,11 @@ namespace cartographer_grpc { namespace framework { namespace { +using EchoResponder = std::function; class MathServerContext : public ExecutionContext { public: int additional_increment() { return 10; } + std::promise echo_responder; }; class GetSumHandler @@ -82,6 +86,21 @@ class GetSquareHandler } }; +class GetEchoHandler + : public RpcHandler { + void OnRequest(const proto::GetEchoRequest& request) override { + int value = request.input(); + Writer writer = GetWriter(); + GetContext()->echo_responder.set_value( + [writer, value]() { + auto response = + cartographer::common::make_unique(); + response->set_output(value); + return writer(std::move(response)); + }); + } +}; + // TODO(cschuet): Due to the hard-coded part these tests will become flaky when // run in parallel. It would be nice to find a way to solve that. gRPC also // allows to communicate over UNIX domain sockets. @@ -99,6 +118,7 @@ class ServerTest : public ::testing::Test { server_builder.RegisterHandler("GetSquare"); server_builder.RegisterHandler( "GetRunningSum"); + server_builder.RegisterHandler("GetEcho"); server_ = server_builder.Build(); client_channel_ = @@ -173,6 +193,32 @@ TEST_F(ServerTest, ProcessBidiStreamingRpcTest) { server_->Shutdown(); } +TEST_F(ServerTest, WriteFromOtherThread) { + server_->SetExecutionContext( + cartographer::common::make_unique()); + server_->Start(); + + proto::GetEchoResponse result; + proto::GetEchoRequest request; + request.set_input(13); + + Server* server = server_.get(); + std::thread response_thread([server]() { + std::future responder_future = + server->GetContext()->echo_responder.get_future(); + responder_future.wait(); + auto responder = responder_future.get(); + CHECK(responder()); + }); + + grpc::Status status = stub_->GetEcho(&client_context_, request, &result); + response_thread.join(); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(result.output(), 13); + + server_->Shutdown(); +} + } // namespace } // namespace framework } // namespace cartographer_grpc