Christoph Schütte 2017-11-29 14:05:31 +01:00 committed by Wally B. Feed
parent 3a46804393
commit 999820d845
10 changed files with 118 additions and 10 deletions

View File

@ -0,0 +1,61 @@
/*
* Copyright 2017 The Cartographer Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_EXECUTION_CONTEXT_H
#define CARTOGRAPHER_GRPC_FRAMEWORK_EXECUTION_CONTEXT_H
#include "cartographer/common/mutex.h"
#include "glog/logging.h"
namespace cartographer_grpc {
namespace framework {
// Implementations of this class allow RPC handlers to share state among one
// another. Using Server::SetExecutionContext(...) a server-wide
// 'ExecutionContext' can be specified. This 'ExecutionContext' can be retrieved
// by all implementations of 'RpcHandler' by calling
// 'RpcHandler::GetContext<MyContext>()'.
class ExecutionContext {
public:
// This non-movable, non-copyable class is used to broker access from various
// RPC handlers to the shared 'ExecutionContext'. Handles automatically lock
// the context they point to.
template <typename ContextType>
class Synchronized {
public:
ContextType* operator->() {
return static_cast<ContextType*>(execution_context_);
}
Synchronized(cartographer::common::Mutex* lock,
ExecutionContext* execution_context)
: locker_(lock), execution_context_(execution_context) {}
Synchronized(const Synchronized&) = delete;
Synchronized(Synchronized&&) = delete;
private:
cartographer::common::MutexLocker locker_;
ExecutionContext* execution_context_;
};
cartographer::common::Mutex* lock() { return &lock_; }
private:
cartographer::common::Mutex lock_;
};
} // namespace framework
} // namespace cartographer_grpc
#endif // CARTOGRAPHER_GRPC_FRAMEWORK_EXECUTION_CONTEXT_H

View File

@ -25,16 +25,18 @@ namespace framework {
Rpc::Rpc(int method_index, Rpc::Rpc(int method_index,
::grpc::ServerCompletionQueue* server_completion_queue, ::grpc::ServerCompletionQueue* server_completion_queue,
ExecutionContext* execution_context,
const RpcHandlerInfo& rpc_handler_info, Service* service) const RpcHandlerInfo& rpc_handler_info, Service* service)
: method_index_(method_index), : method_index_(method_index),
server_completion_queue_(server_completion_queue), server_completion_queue_(server_completion_queue),
execution_context_(execution_context),
rpc_handler_info_(rpc_handler_info), rpc_handler_info_(rpc_handler_info),
service_(service), service_(service),
new_connection_event_{Event::NEW_CONNECTION, this, false}, new_connection_event_{Event::NEW_CONNECTION, this, false},
read_event_{Event::READ, this, false}, read_event_{Event::READ, this, false},
write_event_{Event::WRITE, this, false}, write_event_{Event::WRITE, this, false},
done_event_{Event::DONE, this, false}, done_event_{Event::DONE, this, false},
handler_(rpc_handler_info_.rpc_handler_factory(this)) { handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) {
InitializeReadersAndWriters(rpc_handler_info_.rpc_type); InitializeReadersAndWriters(rpc_handler_info_.rpc_type);
// Initialize the prototypical request and response messages. // Initialize the prototypical request and response messages.
@ -48,7 +50,8 @@ Rpc::Rpc(int method_index,
std::unique_ptr<Rpc> Rpc::Clone() { std::unique_ptr<Rpc> Rpc::Clone() {
return cartographer::common::make_unique<Rpc>( return cartographer::common::make_unique<Rpc>(
method_index_, server_completion_queue_, rpc_handler_info_, service_); method_index_, server_completion_queue_, execution_context_,
rpc_handler_info_, service_);
} }
void Rpc::OnRequest() { handler_->OnRequestInternal(request_.get()); } void Rpc::OnRequest() { handler_->OnRequestInternal(request_.get()); }

View File

@ -21,6 +21,7 @@
#include <unordered_set> #include <unordered_set>
#include "cartographer/common/mutex.h" #include "cartographer/common/mutex.h"
#include "cartographer_grpc/framework/execution_context.h"
#include "cartographer_grpc/framework/rpc_handler_interface.h" #include "cartographer_grpc/framework/rpc_handler_interface.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
@ -47,6 +48,7 @@ class Rpc {
}; };
Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue,
ExecutionContext* execution_context,
const RpcHandlerInfo& rpc_handler_info, Service* service); const RpcHandlerInfo& rpc_handler_info, Service* service);
std::unique_ptr<Rpc> Clone(); std::unique_ptr<Rpc> Clone();
void OnRequest(); void OnRequest();
@ -69,6 +71,7 @@ class Rpc {
int method_index_; int method_index_;
::grpc::ServerCompletionQueue* server_completion_queue_; ::grpc::ServerCompletionQueue* server_completion_queue_;
ExecutionContext* execution_context_;
RpcHandlerInfo rpc_handler_info_; RpcHandlerInfo rpc_handler_info_;
Service* service_; Service* service_;
::grpc::ServerContext server_context_; ::grpc::ServerContext server_context_;

View File

@ -17,6 +17,7 @@
#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H
#define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H #define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H
#include "cartographer_grpc/framework/execution_context.h"
#include "cartographer_grpc/framework/rpc.h" #include "cartographer_grpc/framework/rpc.h"
#include "cartographer_grpc/framework/rpc_handler_interface.h" #include "cartographer_grpc/framework/rpc_handler_interface.h"
#include "cartographer_grpc/framework/type_traits.h" #include "cartographer_grpc/framework/type_traits.h"
@ -35,6 +36,9 @@ class RpcHandler : public RpcHandlerInterface {
using RequestType = StripStream<Incoming>; using RequestType = StripStream<Incoming>;
using ResponseType = StripStream<Outgoing>; using ResponseType = StripStream<Outgoing>;
void SetExecutionContext(ExecutionContext* execution_context) {
execution_context_ = execution_context;
}
void SetRpc(Rpc* rpc) override { rpc_ = rpc; } void SetRpc(Rpc* rpc) override { rpc_ = rpc; }
void OnRequestInternal(const ::google::protobuf::Message* request) override { void OnRequestInternal(const ::google::protobuf::Message* request) override {
DCHECK(dynamic_cast<const RequestType*>(request)); DCHECK(dynamic_cast<const RequestType*>(request));
@ -44,9 +48,14 @@ class RpcHandler : public RpcHandlerInterface {
void Send(std::unique_ptr<ResponseType> response) { void Send(std::unique_ptr<ResponseType> response) {
rpc_->Write(std::move(response)); rpc_->Write(std::move(response));
} }
template <typename T>
ExecutionContext::Synchronized<T> GetContext() {
return {execution_context_->lock(), execution_context_};
}
private: private:
Rpc* rpc_; Rpc* rpc_;
ExecutionContext* execution_context_;
}; };
} // namespace framework } // namespace framework

View File

@ -17,6 +17,7 @@
#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H
#define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H #define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H
#include "cartographer_grpc/framework/execution_context.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
@ -27,14 +28,15 @@ class Rpc;
class RpcHandlerInterface { class RpcHandlerInterface {
public: public:
virtual ~RpcHandlerInterface() = default; virtual ~RpcHandlerInterface() = default;
virtual void SetExecutionContext(ExecutionContext* execution_context) = 0;
virtual void SetRpc(Rpc* rpc) = 0; virtual void SetRpc(Rpc* rpc) = 0;
virtual void OnRequestInternal( virtual void OnRequestInternal(
const ::google::protobuf::Message* request) = 0; const ::google::protobuf::Message* request) = 0;
virtual void OnReadsDone() = 0; virtual void OnReadsDone() = 0;
}; };
using RpcHandlerFactory = using RpcHandlerFactory = std::function<std::unique_ptr<RpcHandlerInterface>(
std::function<std::unique_ptr<RpcHandlerInterface>(Rpc*)>; Rpc*, ExecutionContext*)>;
struct RpcHandlerInfo { struct RpcHandlerInfo {
const google::protobuf::Descriptor* request_descriptor; const google::protobuf::Descriptor* request_descriptor;

View File

@ -77,7 +77,8 @@ void Server::Start() {
// Start serving all services on all completion queues. // Start serving all services on all completion queues.
for (auto& service : services_) { for (auto& service : services_) {
service.second.StartServing(completion_queue_threads_); service.second.StartServing(completion_queue_threads_,
execution_context_.get());
} }
// Start threads to process all completion queues. // Start threads to process all completion queues.
@ -107,5 +108,13 @@ void Server::Shutdown() {
LOG(INFO) << "Shutdown complete."; LOG(INFO) << "Shutdown complete.";
} }
void Server::SetExecutionContext(
std::unique_ptr<ExecutionContext> execution_context) {
// After the server has been started the 'ExecutionHandle' cannot be changed
// anymore.
CHECK(!server_);
execution_context_ = std::move(execution_context);
}
} // namespace framework } // namespace framework
} // namespace cartographer_grpc } // namespace cartographer_grpc

View File

@ -24,6 +24,7 @@
#include "cartographer/common/make_unique.h" #include "cartographer/common/make_unique.h"
#include "cartographer_grpc/framework/completion_queue_thread.h" #include "cartographer_grpc/framework/completion_queue_thread.h"
#include "cartographer_grpc/framework/execution_context.h"
#include "cartographer_grpc/framework/rpc_handler.h" #include "cartographer_grpc/framework/rpc_handler.h"
#include "cartographer_grpc/framework/service.h" #include "cartographer_grpc/framework/service.h"
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
@ -57,10 +58,11 @@ class Server {
RpcHandlerInfo{ RpcHandlerInfo{
RpcHandlerType::RequestType::default_instance().GetDescriptor(), RpcHandlerType::RequestType::default_instance().GetDescriptor(),
RpcHandlerType::ResponseType::default_instance().GetDescriptor(), RpcHandlerType::ResponseType::default_instance().GetDescriptor(),
[](Rpc* const rpc) { [](Rpc* const rpc, ExecutionContext* const execution_context) {
std::unique_ptr<RpcHandlerInterface> rpc_handler = std::unique_ptr<RpcHandlerInterface> rpc_handler =
cartographer::common::make_unique<RpcHandlerType>(); cartographer::common::make_unique<RpcHandlerType>();
rpc_handler->SetRpc(rpc); rpc_handler->SetRpc(rpc);
rpc_handler->SetExecutionContext(execution_context);
return rpc_handler; return rpc_handler;
}, },
RpcType<typename RpcHandlerType::IncomingType, RpcType<typename RpcHandlerType::IncomingType,
@ -81,6 +83,9 @@ class Server {
// Shuts down the server and all of its services. // Shuts down the server and all of its services.
void Shutdown(); void Shutdown();
// Sets the server-wide context object shared between RPC handlers.
void SetExecutionContext(std::unique_ptr<ExecutionContext> execution_context);
private: private:
Server(const Options& options); Server(const Options& options);
Server(const Server&) = delete; Server(const Server&) = delete;
@ -104,6 +109,10 @@ class Server {
// Map of service names to services. // Map of service names to services.
std::map<std::string, Service> services_; std::map<std::string, Service> services_;
// A context object that is shared between all implementations of
// 'RpcHandler'.
std::unique_ptr<ExecutionContext> execution_context_;
}; };
} // namespace framework } // namespace framework

View File

@ -16,6 +16,7 @@
#include "cartographer_grpc/framework/server.h" #include "cartographer_grpc/framework/server.h"
#include "cartographer_grpc/framework/execution_context.h"
#include "cartographer_grpc/framework/proto/math_service.grpc.pb.h" #include "cartographer_grpc/framework/proto/math_service.grpc.pb.h"
#include "cartographer_grpc/framework/proto/math_service.pb.h" #include "cartographer_grpc/framework/proto/math_service.pb.h"
#include "cartographer_grpc/framework/rpc_handler.h" #include "cartographer_grpc/framework/rpc_handler.h"
@ -27,10 +28,16 @@ namespace cartographer_grpc {
namespace framework { namespace framework {
namespace { namespace {
class MathServerContext : public ExecutionContext {
public:
int additional_increment() { return 10; }
};
class GetServerOptionsHandler class GetServerOptionsHandler
: public RpcHandler<Stream<proto::GetSumRequest>, proto::GetSumResponse> { : public RpcHandler<Stream<proto::GetSumRequest>, proto::GetSumResponse> {
public: public:
void OnRequest(const proto::GetSumRequest& request) override { void OnRequest(const proto::GetSumRequest& request) override {
sum_ += GetContext<MathServerContext>()->additional_increment();
sum_ += request.input(); sum_ += request.input();
} }
@ -70,6 +77,8 @@ TEST_F(ServerTest, StartAndStopServerTest) {
} }
TEST_F(ServerTest, ProcessRpcStreamTest) { TEST_F(ServerTest, ProcessRpcStreamTest) {
server_->SetExecutionContext(
cartographer::common::make_unique<MathServerContext>());
server_->Start(); server_->Start();
auto channel = auto channel =
@ -87,7 +96,7 @@ TEST_F(ServerTest, ProcessRpcStreamTest) {
writer->WritesDone(); writer->WritesDone();
grpc::Status status = writer->Finish(); grpc::Status status = writer->Finish();
EXPECT_TRUE(status.ok()); EXPECT_TRUE(status.ok());
EXPECT_EQ(result.output(), 3); EXPECT_EQ(result.output(), 33);
server_->Shutdown(); server_->Shutdown();
} }

View File

@ -37,12 +37,13 @@ Service::Service(const std::string& service_name,
} }
void Service::StartServing( void Service::StartServing(
std::vector<CompletionQueueThread>& completion_queue_threads) { std::vector<CompletionQueueThread>& completion_queue_threads,
ExecutionContext* execution_context) {
int i = 0; int i = 0;
for (const auto& rpc_handler_info : rpc_handler_infos_) { for (const auto& rpc_handler_info : rpc_handler_infos_) {
for (auto& completion_queue_thread : completion_queue_threads) { for (auto& completion_queue_thread : completion_queue_threads) {
Rpc* rpc = active_rpcs_.Add(cartographer::common::make_unique<Rpc>( Rpc* rpc = active_rpcs_.Add(cartographer::common::make_unique<Rpc>(
i, completion_queue_thread.completion_queue(), i, completion_queue_thread.completion_queue(), execution_context,
rpc_handler_info.second, this)); rpc_handler_info.second, this));
rpc->RequestNextMethodInvocation(); rpc->RequestNextMethodInvocation();
} }

View File

@ -18,6 +18,7 @@
#define CARTOGRAPHER_GRPC_FRAMEWORK_SERVICE_H #define CARTOGRAPHER_GRPC_FRAMEWORK_SERVICE_H
#include "cartographer_grpc/framework/completion_queue_thread.h" #include "cartographer_grpc/framework/completion_queue_thread.h"
#include "cartographer_grpc/framework/execution_context.h"
#include "cartographer_grpc/framework/rpc.h" #include "cartographer_grpc/framework/rpc.h"
#include "cartographer_grpc/framework/rpc_handler.h" #include "cartographer_grpc/framework/rpc_handler.h"
#include "grpc++/impl/codegen/service_type.h" #include "grpc++/impl/codegen/service_type.h"
@ -35,7 +36,8 @@ class Service : public ::grpc::Service {
Service(const std::string& service_name, Service(const std::string& service_name,
const std::map<std::string, RpcHandlerInfo>& rpc_handlers); const std::map<std::string, RpcHandlerInfo>& rpc_handlers);
void StartServing(std::vector<CompletionQueueThread>& completion_queues); void StartServing(std::vector<CompletionQueueThread>& completion_queues,
ExecutionContext* execution_context);
void HandleEvent(Rpc::Event event, Rpc* rpc, bool ok); void HandleEvent(Rpc::Event event, Rpc* rpc, bool ok);
void StopServing(); void StopServing();