/* * Copyright 2018 The Cartographer Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_CLIENT_H #define CARTOGRAPHER_GRPC_FRAMEWORK_CLIENT_H #include "cartographer_grpc/framework/retry.h" #include "cartographer_grpc/framework/rpc_handler_interface.h" #include "cartographer_grpc/framework/type_traits.h" #include "glog/logging.h" #include "grpc++/grpc++.h" #include "grpc++/impl/codegen/client_unary_call.h" #include "grpc++/impl/codegen/sync_stream.h" namespace cartographer_grpc { namespace framework { template class Client { public: Client(std::shared_ptr channel, RetryStrategy retry_strategy) : channel_(channel), client_context_( cartographer::common::make_unique()), rpc_method_name_( RpcHandlerInterface::Instantiate()->method_name()), rpc_method_(rpc_method_name_.c_str(), RpcType::value, channel_), retry_strategy_(retry_strategy) { CHECK(!retry_strategy || rpc_method_.method_type() == ::grpc::internal::RpcMethod::NORMAL_RPC) << "Retry is currently only support for NORMAL_RPC."; } Client(std::shared_ptr channel) : channel_(channel), client_context_( cartographer::common::make_unique()), rpc_method_name_( RpcHandlerInterface::Instantiate()->method_name()), rpc_method_(rpc_method_name_.c_str(), RpcType::value, channel_) {} bool Read(typename RpcHandlerType::ResponseType *response) { switch (rpc_method_.method_type()) { case grpc::internal::RpcMethod::BIDI_STREAMING: InstantiateClientReaderWriterIfNeeded(); return client_reader_writer_->Read(response); case grpc::internal::RpcMethod::SERVER_STREAMING: CHECK(client_reader_); return client_reader_->Read(response); default: LOG(FATAL) << "Not implemented."; } } bool Write(const typename RpcHandlerType::RequestType &request) { return RetryWithStrategy( retry_strategy_, std::bind(&Client::WriteImpl, this, request), std::bind(&Client::Reset, this)); } bool WritesDone() { switch (rpc_method_.method_type()) { case grpc::internal::RpcMethod::CLIENT_STREAMING: InstantiateClientWriterIfNeeded(); return client_writer_->WritesDone(); case grpc::internal::RpcMethod::BIDI_STREAMING: InstantiateClientReaderWriterIfNeeded(); return client_reader_writer_->WritesDone(); default: LOG(FATAL) << "Not implemented."; } } grpc::Status Finish() { switch (rpc_method_.method_type()) { case grpc::internal::RpcMethod::CLIENT_STREAMING: InstantiateClientWriterIfNeeded(); return client_writer_->Finish(); case grpc::internal::RpcMethod::BIDI_STREAMING: InstantiateClientReaderWriterIfNeeded(); return client_reader_writer_->Finish(); case grpc::internal::RpcMethod::SERVER_STREAMING: CHECK(client_reader_); return client_reader_->Finish(); default: LOG(FATAL) << "Not implemented."; } } const typename RpcHandlerType::ResponseType &response() { CHECK(rpc_method_.method_type() == grpc::internal::RpcMethod::NORMAL_RPC || rpc_method_.method_type() == grpc::internal::RpcMethod::CLIENT_STREAMING); return response_; } private: void Reset() { client_context_ = cartographer::common::make_unique(); } bool WriteImpl(const typename RpcHandlerType::RequestType &request) { switch (rpc_method_.method_type()) { case grpc::internal::RpcMethod::NORMAL_RPC: return MakeBlockingUnaryCall(request, &response_).ok(); case grpc::internal::RpcMethod::CLIENT_STREAMING: InstantiateClientWriterIfNeeded(); return client_writer_->Write(request); case grpc::internal::RpcMethod::BIDI_STREAMING: InstantiateClientReaderWriterIfNeeded(); return client_reader_writer_->Write(request); case grpc::internal::RpcMethod::SERVER_STREAMING: InstantiateClientReader(request); return true; } LOG(FATAL) << "Not reached."; } void InstantiateClientWriterIfNeeded() { CHECK_EQ(rpc_method_.method_type(), grpc::internal::RpcMethod::CLIENT_STREAMING); if (!client_writer_) { client_writer_.reset( grpc::internal:: ClientWriterFactory::Create( channel_.get(), rpc_method_, client_context_.get(), &response_)); } } void InstantiateClientReaderWriterIfNeeded() { CHECK_EQ(rpc_method_.method_type(), grpc::internal::RpcMethod::BIDI_STREAMING); if (!client_reader_writer_) { client_reader_writer_.reset( grpc::internal::ClientReaderWriterFactory< typename RpcHandlerType::RequestType, typename RpcHandlerType::ResponseType>::Create(channel_.get(), rpc_method_, client_context_ .get())); } } void InstantiateClientReader( const typename RpcHandlerType::RequestType &request) { CHECK_EQ(rpc_method_.method_type(), grpc::internal::RpcMethod::SERVER_STREAMING); client_reader_.reset( grpc::internal:: ClientReaderFactory::Create( channel_.get(), rpc_method_, client_context_.get(), request)); } grpc::Status MakeBlockingUnaryCall( const typename RpcHandlerType::RequestType &request, typename RpcHandlerType::ResponseType *response) { CHECK_EQ(rpc_method_.method_type(), grpc::internal::RpcMethod::NORMAL_RPC); return ::grpc::internal::BlockingUnaryCall( channel_.get(), rpc_method_, client_context_.get(), request, response); } std::shared_ptr channel_; std::unique_ptr client_context_; const std::string rpc_method_name_; const ::grpc::internal::RpcMethod rpc_method_; std::unique_ptr> client_writer_; std::unique_ptr< grpc::ClientReaderWriter> client_reader_writer_; std::unique_ptr> client_reader_; typename RpcHandlerType::ResponseType response_; RetryStrategy retry_strategy_; }; } // namespace framework } // namespace cartographer_grpc #endif // CARTOGRAPHER_GRPC_FRAMEWORK_CLIENT_H