small_gicp
sort_tbb.hpp
Go to the documentation of this file.
1 // SPDX-FileCopyrightText: Copyright 2024 Kenji Koide
2 // SPDX-License-Identifier: MIT
3 #pragma once
4 
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <type_traits>
#include <vector>
#include <tbb/tbb.h>
9 
10 namespace small_gicp {
11 
/// @brief Temporary buffers for radix sort.
/// @note  radix_sort_tbb only resize()s these vectors, so passing the same instance to repeated
///        calls lets them keep their capacity and avoids reallocations.
template <typename T>
struct RadixSortBuffers {
  std::vector<std::uint64_t> tile_buckets;    ///< Tiled buckets (per-bin, per-tile counts, later per-tile offsets)
  std::vector<std::uint64_t> global_offsets;  ///< Global offsets (first output index of each bin)
  std::vector<T> sorted_buffer;               ///< Sorted objects (scratch output buffer)
};
19 
31 template <typename T, typename KeyFunc, int bits = 8, int tile_size = 256>
32 void radix_sort_tbb(T* first_, T* last_, const KeyFunc& key_, RadixSortBuffers<T>& buffers) {
33  if (first_ == last_) {
34  return;
35  }
36 
37  auto first = first_;
38  auto last = last_;
39 
40  using Key = decltype(key_(*first));
41  static_assert(std::is_unsigned_v<Key>, "Key must be unsigned integral type");
42 
43  // Number of total radix sort steps.
44  constexpr int num_steps = (sizeof(Key) * 8 + bits - 1) / bits;
45 
46  constexpr int num_bins = 1 << bits;
47  const std::uint64_t N = std::distance(first, last);
48  const std::uint64_t num_tiles = (N + tile_size - 1) / tile_size;
49 
50  // Allocate buffers.
51  auto& tile_buckets = buffers.tile_buckets;
52  auto& global_offsets = buffers.global_offsets;
53  auto& sorted_buffer = buffers.sorted_buffer;
54  tile_buckets.resize(num_bins * num_tiles);
55  global_offsets.resize(num_bins);
56  sorted_buffer.resize(N);
57 
58  auto sorted = sorted_buffer.data();
59 
60  // Radix sort.
61  for (int step = 0; step < num_steps; step++) {
62  const auto key = [&](const auto& x) { return ((key_(x) >> (step * bits))) & ((1 << bits) - 1); };
63 
64  // Create per-tile histograms.
65  std::fill(tile_buckets.begin(), tile_buckets.end(), 0);
66  tbb::parallel_for(tbb::blocked_range<std::uint64_t>(0, num_tiles, 4), [&](const tbb::blocked_range<std::uint64_t>& r) {
67  for (std::uint64_t tile = r.begin(); tile < r.end(); tile++) {
68  std::uint64_t data_begin = tile * tile_size;
69  std::uint64_t data_end = std::min<std::uint64_t>((tile + 1) * tile_size, N);
70 
71  for (int i = data_begin; i < data_end; ++i) {
72  auto buckets = tile_buckets.data() + key(*(first + i)) * num_tiles;
73  ++buckets[tile];
74  }
75  }
76  });
77 
78  // Store the number of elements of the last tile, which will be overwritten by the next step, in global_offsets.
79  std::fill(global_offsets.begin(), global_offsets.end(), 0);
80  for (int i = 1; i < num_bins; i++) {
81  global_offsets[i] = tile_buckets[i * num_tiles - 1];
82  }
83 
84  // Calculate per-tile offsets.
85  tbb::parallel_for(tbb::blocked_range<std::uint64_t>(0, num_bins, 1), [&](const tbb::blocked_range<std::uint64_t>& r) {
86  for (std::uint64_t bin = r.begin(); bin < r.end(); bin++) {
87  auto buckets = tile_buckets.data() + bin * num_tiles;
88  std::uint64_t last = buckets[0];
89  buckets[0] = 0;
90 
91  for (std::uint64_t tile = 1; tile < num_tiles; tile++) {
92  std::uint64_t tmp = buckets[tile];
93  buckets[tile] = buckets[tile - 1] + last;
94  last = tmp;
95  }
96  }
97  });
98 
99  // Calculate global offsets for each sorting bin.
100  for (int i = 1; i < num_bins; i++) {
101  global_offsets[i] += global_offsets[i - 1] + tile_buckets[i * num_tiles - 1];
102  }
103 
104  // Sort elements.
105  tbb::parallel_for(tbb::blocked_range<std::uint64_t>(0, num_tiles, 8), [&](const tbb::blocked_range<std::uint64_t>& r) {
106  for (std::uint64_t tile = r.begin(); tile < r.end(); ++tile) {
107  std::uint64_t data_begin = tile * tile_size;
108  std::uint64_t data_end = std::min((tile + 1) * tile_size, static_cast<std::uint64_t>(N));
109 
110  for (std::uint64_t i = data_begin; i < data_end; ++i) {
111  const T x = *(first + i);
112  const int bin = key(x);
113  auto offset = tile_buckets.data() + bin * num_tiles + tile;
114  sorted[global_offsets[bin] + ((*offset)++)] = x;
115  }
116  }
117  });
118 
119  // Swap input and output buffers.
120  std::swap(first, sorted);
121  }
122 
123  // Copy the result to the original buffer.
124  if (num_steps % 2 == 1) {
125  std::copy(sorted_buffer.begin(), sorted_buffer.end(), first_);
126  }
127 }
128 
137 template <typename T, typename KeyFunc, int bits = 4, int tile_size = 256>
138 void radix_sort_tbb(T* first_, T* last_, const KeyFunc& key_) {
139  RadixSortBuffers<T> buffers;
140  radix_sort_tbb(first_, last_, key_, buffers);
141 }
142 
143 } // namespace small_gicp
Definition: flat_container.hpp:12
void radix_sort_tbb(T *first_, T *last_, const KeyFunc &key_, RadixSortBuffers< T > &buffers)
Radix sort with TBB parallelization.
Definition: sort_tbb.hpp:32
void radix_sort_tbb(T *first_, T *last_, const KeyFunc &key_)
Radix sort with TBB parallelization.
Definition: sort_tbb.hpp:138
Temporary buffers for radix sort.
Definition: sort_tbb.hpp:14
std::vector< T > sorted_buffer
Definition: sort_tbb.hpp:17
std::vector< std::uint64_t > tile_buckets
Definition: sort_tbb.hpp:15
std::vector< std::uint64_t > global_offsets
Definition: sort_tbb.hpp:16