diff --git a/cartographer_grpc/framework/server.h b/cartographer_grpc/framework/server.h index 226de42..f44bd4d 100644 --- a/cartographer_grpc/framework/server.h +++ b/cartographer_grpc/framework/server.h @@ -62,6 +62,7 @@ class Server { std::string method_name; std::tie(service_full_name, method_name) = ParseMethodFullName(method_full_name); + CheckHandlerCompatibility(service_full_name, method_name); rpc_handlers_[service_full_name].emplace( method_name, RpcHandlerInfo{ @@ -91,6 +92,44 @@ class Server { std::string /* method_name */> ParseMethodFullName(const std::string& method_full_name); + template + void CheckHandlerCompatibility(const std::string& service_full_name, + const std::string& method_name) { + const auto* pool = google::protobuf::DescriptorPool::generated_pool(); + const auto* service = pool->FindServiceByName(service_full_name); + CHECK(service) << "Unknown service " << service_full_name; + const auto* method_descriptor = service->FindMethodByName(method_name); + CHECK(method_descriptor) << "Unknown method " << method_name + << " in service " << service_full_name; + const auto* request_type = method_descriptor->input_type(); + CHECK_EQ(RpcHandlerType::RequestType::default_instance().GetDescriptor(), + request_type); + const auto* response_type = method_descriptor->output_type(); + CHECK_EQ(RpcHandlerType::ResponseType::default_instance().GetDescriptor(), + response_type); + const auto rpc_type = + RpcType::value; + switch (rpc_type) { + case ::grpc::internal::RpcMethod::NORMAL_RPC: + CHECK(!method_descriptor->client_streaming()); + CHECK(!method_descriptor->server_streaming()); + break; + case ::grpc::internal::RpcMethod::CLIENT_STREAMING: + CHECK(method_descriptor->client_streaming()); + CHECK(!method_descriptor->server_streaming()); + break; + case ::grpc::internal::RpcMethod::SERVER_STREAMING: + CHECK(!method_descriptor->client_streaming()); + CHECK(method_descriptor->server_streaming()); + break; + case ::grpc::internal::RpcMethod::BIDI_STREAMING: + CHECK(method_descriptor->client_streaming()); + CHECK(method_descriptor->server_streaming()); + break; + } + } + Options options_; std::map rpc_handlers_; };