spla
cl_sort_by_key.hpp
Go to the documentation of this file.
1 /**********************************************************************************/
2 /* This file is part of spla project */
3 /* https://github.com/JetBrains-Research/spla */
4 /**********************************************************************************/
5 /* MIT License */
6 /* */
7 /* Copyright (c) 2023 SparseLinearAlgebra */
8 /* */
9 /* Permission is hereby granted, free of charge, to any person obtaining a copy */
10 /* of this software and associated documentation files (the "Software"), to deal */
11 /* in the Software without restriction, including without limitation the rights */
12 /* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
13 /* copies of the Software, and to permit persons to whom the Software is */
14 /* furnished to do so, subject to the following conditions: */
15 /* */
16 /* The above copyright notice and this permission notice shall be included in all */
17 /* copies or substantial portions of the Software. */
18 /* */
19 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
20 /* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
21 /* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
22 /* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
23 /* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
24 /* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
25 /* SOFTWARE. */
26 /**********************************************************************************/
27 
28 #ifndef SPLA_CL_SORT_BY_KEY_HPP
29 #define SPLA_CL_SORT_BY_KEY_HPP
30 
32 #include <opencl/cl_alloc.hpp>
34 #include <opencl/cl_prefix_sum.hpp>
38 
39 #include <cmath>
40 
41 namespace spla {
42 
43  template<typename T>
44  void cl_sort_by_key_bitonic(cl::CommandQueue& queue, cl::Buffer& keys, cl::Buffer& values, uint size) {
45  if (size <= 1) {
46  LOG_MSG(Status::Ok, "nothing to do");
47  return;
48  }
49 
50  auto* acc = get_acc_cl();
51 
52  const uint pair_size = sizeof(uint) + sizeof(T);
53  const uint local_size = floor_to_pow2(acc->get_max_local_mem() / pair_size);
54  const uint max_treads_per_block = std::min(acc->get_max_wgs(), local_size / 2u);
55 
56  assert(local_size > 2);
57 
58  CLProgramBuilder builder;
59  builder.set_name("sort_bitonic")
60  .add_type("TYPE", get_ttype<T>().template as<Type>())
61  .add_define("BLOCK_SIZE", local_size)
62  .set_source(source_sort_bitonic)
63  .acquire();
64 
65  auto kernel_local = builder.make_kernel("bitonic_sort_local");
66  kernel_local.setArg(0, keys);
67  kernel_local.setArg(1, values);
68  kernel_local.setArg(2, size);
69 
70  if (size <= local_size) {
71  const uint wave_size = acc->get_wave_size();
72  const uint n_threads = align(std::min(size, max_treads_per_block), wave_size);
73 
74  cl::NDRange global(n_threads);
75  cl::NDRange local(n_threads);
76  queue.enqueueNDRangeKernel(kernel_local, cl::NDRange(), global, local);
77  return;
78  }
79 
80  const uint n_groups = div_up(size, local_size);
81 
82  cl::NDRange step_pre_sort_global(max_treads_per_block * n_groups);
83  cl::NDRange step_pre_sort_local(max_treads_per_block);
84  queue.enqueueNDRangeKernel(kernel_local, cl::NDRange(), step_pre_sort_global, step_pre_sort_local);
85 
86  auto kernel_global = builder.make_kernel("bitonic_sort_global");
87  kernel_global.setArg(0, keys);
88  kernel_global.setArg(1, values);
89  kernel_global.setArg(2, size);
90  kernel_global.setArg(3, uint(local_size * 2));
91 
92  cl::NDRange step_final_global(acc->get_max_wgs());
93  cl::NDRange step_final_local(acc->get_max_wgs());
94  queue.enqueueNDRangeKernel(kernel_global, cl::NDRange(), step_final_global, step_final_local);
95  }
96 
97  template<typename T>
98  void cl_sort_by_key_radix(cl::CommandQueue& queue, cl::Buffer& keys, cl::Buffer& values, uint n, CLAlloc* tmp_alloc, uint max_key = 0xffffffff) {
99  if (n <= 1) {
100  LOG_MSG(Status::Ok, "nothing to do");
101  return;
102  }
103 
104  const uint BITS_COUNT = 4;
105  const uint BITS_VALS = 1 << BITS_COUNT;
106  const uint BITS_MASK = BITS_VALS - 1;
107 
108  auto* cl_acc = get_acc_cl();
109  const uint block_size = cl_acc->get_default_wgs();
110 
111  CLProgramBuilder builder;
112  builder.set_name("radix_sort")
113  .add_define("BLOCK_SIZE", block_size)
114  .add_define("BITS_VALS", BITS_VALS)
115  .add_define("BITS_MASK", BITS_MASK)
116  .add_type("TYPE", get_ttype<T>().template as<Type>())
117  .set_source(source_sort_radix)
118  .acquire();
119 
120  const uint n_treads_total = align(n, block_size);
121  const uint n_groups = div_up(n, block_size);
122  const uint n_blocks_sizes = n_groups * BITS_VALS;
123 
124  cl::Buffer cl_temp_keys;
125  cl::Buffer cl_temp_values;
126  cl_acc->get_alloc_general()->alloc_paired(sizeof(uint) * n, sizeof(T) * n, cl_temp_keys, cl_temp_values);
127 
128  cl::Buffer cl_offsets = tmp_alloc->alloc(sizeof(uint) * n);
129  cl::Buffer cl_blocks_size = tmp_alloc->alloc(sizeof(uint) * n_blocks_sizes);
130 
131  auto kernel_local = builder.make_kernel("radix_sort_local");
132  auto kernel_scatter = builder.make_kernel("radix_sort_scatter");
133 
134  const uint bits_in_max_key = static_cast<uint>(std::floor(std::log2(float(max_key)))) + 1;
135  const uint bits_aligned = align(bits_in_max_key, BITS_COUNT);
136  const uint max_bits = std::min(32u, bits_aligned);
137 
138  cl::Buffer in_keys = keys;
139  cl::Buffer in_values = values;
140  cl::Buffer out_keys = cl_temp_keys;
141  cl::Buffer out_values = cl_temp_values;
142 
143  for (uint shift = 0; shift <= max_bits - BITS_COUNT; shift += BITS_COUNT) {
144  cl::NDRange global(n_treads_total);
145  cl::NDRange local(block_size);
146 
147  kernel_local.setArg(0, in_keys);
148  kernel_local.setArg(1, cl_offsets);
149  kernel_local.setArg(2, cl_blocks_size);
150  kernel_local.setArg(3, n);
151  kernel_local.setArg(4, shift);
152  queue.enqueueNDRangeKernel(kernel_local, cl::NDRange(), global, local);
153 
154  cl_exclusive_scan<uint>(queue, cl_blocks_size, n_blocks_sizes, PLUS_UINT.template cast_safe<TOpBinary<uint, uint, uint>>(), tmp_alloc);
155 
156  kernel_scatter.setArg(0, in_keys);
157  kernel_scatter.setArg(1, in_values);
158  kernel_scatter.setArg(2, out_keys);
159  kernel_scatter.setArg(3, out_values);
160  kernel_scatter.setArg(4, cl_offsets);
161  kernel_scatter.setArg(5, cl_blocks_size);
162  kernel_scatter.setArg(6, n);
163  kernel_scatter.setArg(7, shift);
164  queue.enqueueNDRangeKernel(kernel_scatter, cl::NDRange(), global, local);
165 
166  std::swap(in_keys, out_keys);
167  std::swap(in_values, out_values);
168  }
169 
170  keys = in_keys;
171  values = in_values;
172  }
173 
174  template<typename T>
175  void cl_sort_by_key(cl::CommandQueue& queue, cl::Buffer& keys, cl::Buffer& values, uint n, CLAlloc* tmp_alloc, uint max_key = 0xffffffff) {
176  if (n <= 1) {
177  LOG_MSG(Status::Ok, "nothing to do");
178  return;
179  }
180 
181  const uint sort_switch = 2u << 14u;
182 
183  if (n <= sort_switch) {
184  cl_sort_by_key_bitonic<T>(queue, keys, values, n);
185  } else {
186  cl_sort_by_key_radix<T>(queue, keys, values, n, tmp_alloc, max_key);
187  }
188  }
189 
190 }// namespace spla
191 
192 #endif//SPLA_CL_SORT_BY_KEY_HPP
Base class for any device-local opencl buffer allocator.
Definition: cl_alloc.hpp:39
virtual cl::Buffer alloc(std::size_t size)=0
Runtime opencl program builder.
Definition: cl_program_builder.hpp:55
CLProgramBuilder & add_define(const char *define, int value)
Definition: cl_program_builder.cpp:41
CLProgramBuilder & set_name(const char *name)
Definition: cl_program_builder.cpp:37
CLProgramBuilder & add_type(const char *alias, const ref_ptr< Type > &type)
Definition: cl_program_builder.cpp:45
cl::Kernel make_kernel(const char *name)
Definition: cl_program_builder.hpp:67
CLProgramBuilder & set_source(const char *source)
Definition: cl_program_builder.cpp:61
void acquire()
Definition: cl_program_builder.cpp:65
Definition: top.hpp:174
ref_ptr< OpBinary > PLUS_UINT
Definition: op.cpp:76
std::uint32_t uint
Library index and size type.
Definition: config.hpp:56
#define LOG_MSG(status, msg)
Definition: logger.hpp:66
Definition: algorithm.hpp:37
void cl_sort_by_key(cl::CommandQueue &queue, cl::Buffer &keys, cl::Buffer &values, uint n, CLAlloc *tmp_alloc, uint max_key=0xffffffff)
Definition: cl_sort_by_key.hpp:175
void cl_sort_by_key_radix(cl::CommandQueue &queue, cl::Buffer &keys, cl::Buffer &values, uint n, CLAlloc *tmp_alloc, uint max_key=0xffffffff)
Definition: cl_sort_by_key.hpp:98
T min(T a, T b)
Definition: op.cpp:152
void cl_sort_by_key_bitonic(cl::CommandQueue &queue, cl::Buffer &keys, cl::Buffer &values, uint size)
Definition: cl_sort_by_key.hpp:44