diff --git a/cartographer/mapping/id.h b/cartographer/mapping/id.h index 1f772ef..b7dadaa 100644 --- a/cartographer/mapping/id.h +++ b/cartographer/mapping/id.h @@ -18,6 +18,7 @@ #define CARTOGRAPHER_MAPPING_ID_H_ #include +#include #include #include #include @@ -28,6 +29,7 @@ #include "cartographer/common/make_unique.h" #include "cartographer/common/port.h" +#include "cartographer/common/time.h" #include "glog/logging.h" namespace cartographer { @@ -322,6 +324,35 @@ class MapById { bool empty() const { return begin() == end(); } + // Returns an iterator to the the first element in the container belonging to + // trajectory 'trajectory_id' whose time is not considered to go before + // 'time', or EndOfTrajectory(trajectory_id) if all keys are considered to go + // before 'time'. + ConstIterator lower_bound(const int trajectory_id, const common::Time& time) { + if (SizeOfTrajectoryOrZero(trajectory_id) == 0) { + return EndOfTrajectory(trajectory_id); + } + + const std::map& trajectory = + trajectories_.at(trajectory_id).data_; + if (std::prev(trajectory.end())->second.time() < time) { + return EndOfTrajectory(trajectory_id); + } + auto left = trajectory.begin(); + auto right = std::prev(trajectory.end()); + while (left != right) { + const int middle = left->first + (right->first - left->first) / 2; + const auto lower_bound_middle = trajectory.lower_bound(middle); + if (lower_bound_middle->second.time() < time) { + left = std::next(lower_bound_middle); + } else { + right = lower_bound_middle; + } + } + + return ConstIterator(*this, IdType{trajectory_id, left->first}); + } + private: struct MapByIndex { bool can_append_ = true; diff --git a/cartographer/mapping/id_test.cc b/cartographer/mapping/id_test.cc index 8a6b23a..fce3fc9 100644 --- a/cartographer/mapping/id_test.cc +++ b/cartographer/mapping/id_test.cc @@ -16,16 +16,35 @@ #include "cartographer/mapping/id.h" +#include #include #include +#include #include +#include "cartographer/common/time.h" #include "gtest/gtest.h" namespace cartographer { namespace mapping { namespace { +common::Time CreateTime(const int milliseconds) { + return common::Time(common::FromMilliseconds(milliseconds)); +} + +class Data { + public: + Data(int milliseconds) : time_(CreateTime(milliseconds)) {} + + const common::Time& time() const { + return time_; + } + + private: + const common::Time time_; +}; + template static MapById CreateTestMapById() { MapById map_by_id; @@ -145,6 +164,65 @@ TEST(IdTest, FindSubmapId) { EXPECT_TRUE(map_by_id.find(SubmapId{42, 3}) == map_by_id.end()); } +TEST(IdTest, LowerBoundEdgeCases) { + MapById map_by_id; + map_by_id.Append(0, Data(1)); + map_by_id.Append(2, Data(2)); + CHECK(map_by_id.lower_bound(1, CreateTime(10)) == + map_by_id.EndOfTrajectory(1)); + CHECK(map_by_id.lower_bound(2, CreateTime(3)) == + map_by_id.EndOfTrajectory(2)); + CHECK(map_by_id.lower_bound(2, CreateTime(1)) == + map_by_id.BeginOfTrajectory(2)); +} + +TEST(IdTest, LowerBound) { + MapById map_by_id; + map_by_id.Append(0, Data(1)); + map_by_id.Append(0, Data(2)); + map_by_id.Append(0, Data(4)); + map_by_id.Append(0, Data(5)); + CHECK(map_by_id.lower_bound(0, CreateTime(3)) == + (MapById::ConstIterator(map_by_id, SubmapId{0, 2}))); + CHECK(map_by_id.lower_bound(0, CreateTime(2)) == + (MapById::ConstIterator(map_by_id, SubmapId{0, 1}))); + CHECK(map_by_id.lower_bound(0, CreateTime(4)) == + (MapById::ConstIterator(map_by_id, SubmapId{0, 2}))); +} + +TEST(IdTest, LowerBoundFuzz) { + constexpr int kMaxTimeIncrement = 20; + constexpr int kMaxNumNodes = 20; + constexpr int kNumTests = 100; + constexpr int kTrajectoryId = 1; + + std::mt19937 rng; + std::uniform_int_distribution dt_dist(1, kMaxTimeIncrement); + std::uniform_int_distribution N_dist(1, kMaxNumNodes); + + for (int i = 0; i < kNumTests; ++i) { + const int N = N_dist(rng); + int t = 0; + MapById map_by_id; + for (int j = 0; j < N; ++j) { + t = t + dt_dist(rng); + map_by_id.Append(kTrajectoryId, Data(t)); + } + std::uniform_int_distribution t0_dist(1, N * kMaxTimeIncrement + 1); + int t0 = t0_dist(rng); + auto it = map_by_id.lower_bound(kTrajectoryId, CreateTime(t0)); + + auto ground_truth = std::lower_bound( + map_by_id.BeginOfTrajectory(kTrajectoryId), + map_by_id.EndOfTrajectory(kTrajectoryId), CreateTime(t0), + [](MapById::IdDataReference a, const common::Time& t) { + return a.data.time() < t; + }); + + CHECK(ground_truth == it); + } +} + } // namespace } // namespace mapping } // namespace cartographer