diff --git a/cartographer_grpc/framework/rpc_handler.h b/cartographer_grpc/framework/rpc_handler.h index f625efb..901cb90 100644 --- a/cartographer_grpc/framework/rpc_handler.h +++ b/cartographer_grpc/framework/rpc_handler.h @@ -35,7 +35,28 @@ class RpcHandler : public RpcHandlerInterface { using OutgoingType = Outgoing; using RequestType = StripStream; using ResponseType = StripStream; - using Writer = std::function)>; + + class Writer { + public: + explicit Writer(std::weak_ptr rpc) : rpc_(std::move(rpc)) {} + bool Write(std::unique_ptr message) const { + if (auto rpc = rpc_.lock()) { + rpc->Write(std::move(message)); + return true; + } + return false; + } + bool WritesDone() const { + if (auto rpc = rpc_.lock()) { + rpc->Finish(::grpc::Status::OK); + return true; + } + return false; + } + + private: + const std::weak_ptr rpc_; + }; void SetExecutionContext(ExecutionContext* execution_context) { execution_context_ = execution_context; @@ -58,16 +79,7 @@ class RpcHandler : public RpcHandlerInterface { T* GetUnsynchronizedContext() { return dynamic_cast(execution_context_); } - Writer GetWriter() { - std::weak_ptr weak_ptr_rpc = rpc_->GetWeakPtr(); - return [weak_ptr_rpc](std::unique_ptr message) { - if (auto rpc = weak_ptr_rpc.lock()) { - rpc->Write(std::move(message)); - return true; - } - return false; - }; - } + Writer GetWriter() { return Writer(rpc_->GetWeakPtr()); } private: Rpc* rpc_; diff --git a/cartographer_grpc/framework/server_test.cc b/cartographer_grpc/framework/server_test.cc index d0c7d42..884ef63 100644 --- a/cartographer_grpc/framework/server_test.cc +++ b/cartographer_grpc/framework/server_test.cc @@ -96,7 +96,7 @@ class GetEchoHandler auto response = cartographer::common::make_unique(); response->set_output(value); - return writer(std::move(response)); + return writer.Write(std::move(response)); }); } }; diff --git a/cartographer_grpc/handlers/finish_trajectory_handler.h b/cartographer_grpc/handlers/finish_trajectory_handler.h index 30ff31f..29fe247 100644 --- a/cartographer_grpc/handlers/finish_trajectory_handler.h +++ b/cartographer_grpc/handlers/finish_trajectory_handler.h @@ -34,6 +34,8 @@ class FinishTrajectoryHandler GetContext() ->map_builder() .FinishTrajectory(request.trajectory_id()); + GetUnsynchronizedContext() + ->NotifyFinishTrajectory(request.trajectory_id()); Send(std::move( cartographer::common::make_unique())); } diff --git a/cartographer_grpc/handlers/receive_local_slam_results_handler.h b/cartographer_grpc/handlers/receive_local_slam_results_handler.h index 03a431e..534c39c 100644 --- a/cartographer_grpc/handlers/receive_local_slam_results_handler.h +++ b/cartographer_grpc/handlers/receive_local_slam_results_handler.h @@ -39,7 +39,14 @@ class ReceiveLocalSlamResultsHandler request.trajectory_id(), [writer](std::unique_ptr local_slam_result) { - writer(GenerateResponse(std::move(local_slam_result))); + if (local_slam_result) { + writer.Write( + GenerateResponse(std::move(local_slam_result))); + } else { + // Callback with 'nullptr' signals that the trajectory + // finished. + writer.WritesDone(); + } }); subscription_id_ = diff --git a/cartographer_grpc/map_builder_server.cc b/cartographer_grpc/map_builder_server.cc index 65499b0..c9e3255 100644 --- a/cartographer_grpc/map_builder_server.cc +++ b/cartographer_grpc/map_builder_server.cc @@ -83,6 +83,11 @@ void MapBuilderServer::MapBuilderContext::UnsubscribeLocalSlamResults( map_builder_server_->UnsubscribeLocalSlamResults(subscription_id); } +void MapBuilderServer::MapBuilderContext::NotifyFinishTrajectory( + int trajectory_id) { + map_builder_server_->NotifyFinishTrajectory(trajectory_id); +} + MapBuilderServer::MapBuilderServer( const proto::MapBuilderServerOptions& map_builder_server_options, std::unique_ptr map_builder) @@ -193,4 +198,13 @@ void MapBuilderServer::UnsubscribeLocalSlamResults( 1u); } +void MapBuilderServer::NotifyFinishTrajectory(int trajectory_id) { + cartographer::common::MutexLocker locker(&local_slam_subscriptions_lock_); + for (auto& entry : local_slam_subscriptions_[trajectory_id]) { + LocalSlamSubscriptionCallback callback = entry.second; + // 'nullptr' signals subscribers that the trajectory finished. + callback(nullptr); + } +} + } // namespace cartographer_grpc diff --git a/cartographer_grpc/map_builder_server.h b/cartographer_grpc/map_builder_server.h index 08e17ba..04a49bf 100644 --- a/cartographer_grpc/map_builder_server.h +++ b/cartographer_grpc/map_builder_server.h @@ -37,6 +37,7 @@ class MapBuilderServer { std::shared_ptr range_data; std::unique_ptr node_id; }; + // Calling with 'nullptr' signals subscribers that the subscription has ended. using LocalSlamSubscriptionCallback = std::function)>; struct SensorData { @@ -60,6 +61,7 @@ class MapBuilderServer { SubscriptionId SubscribeLocalSlamResults( int trajectory_id, LocalSlamSubscriptionCallback callback); void UnsubscribeLocalSlamResults(const SubscriptionId& subscription_id); + void NotifyFinishTrajectory(int trajectory_id); template void EnqueueSensorData(int trajectory_id, const std::string& sensor_id, @@ -105,6 +107,7 @@ class MapBuilderServer { SubscriptionId SubscribeLocalSlamResults( int trajectory_id, LocalSlamSubscriptionCallback callback); void UnsubscribeLocalSlamResults(const SubscriptionId& subscription_id); + void NotifyFinishTrajectory(int trajectory_id); bool shutting_down_ = false; std::unique_ptr slam_thread_;