28 #ifndef SPLA_CL_SORT_BY_KEY_HPP
29 #define SPLA_CL_SORT_BY_KEY_HPP
50 auto* acc = get_acc_cl();
52 const uint pair_size =
sizeof(
uint) +
sizeof(T);
53 const uint local_size = floor_to_pow2(acc->get_max_local_mem() / pair_size);
54 const uint max_treads_per_block =
std::min(acc->get_max_wgs(), local_size / 2u);
56 assert(local_size > 2);
60 .
add_type(
"TYPE", get_ttype<T>().
template as<Type>())
65 auto kernel_local = builder.
make_kernel(
"bitonic_sort_local");
66 kernel_local.setArg(0, keys);
67 kernel_local.setArg(1, values);
68 kernel_local.setArg(2, size);
70 if (size <= local_size) {
71 const uint wave_size = acc->get_wave_size();
72 const uint n_threads = align(
std::min(size, max_treads_per_block), wave_size);
74 cl::NDRange global(n_threads);
75 cl::NDRange local(n_threads);
76 queue.enqueueNDRangeKernel(kernel_local, cl::NDRange(), global, local);
80 const uint n_groups = div_up(size, local_size);
82 cl::NDRange step_pre_sort_global(max_treads_per_block * n_groups);
83 cl::NDRange step_pre_sort_local(max_treads_per_block);
84 queue.enqueueNDRangeKernel(kernel_local, cl::NDRange(), step_pre_sort_global, step_pre_sort_local);
86 auto kernel_global = builder.
make_kernel(
"bitonic_sort_global");
87 kernel_global.setArg(0, keys);
88 kernel_global.setArg(1, values);
89 kernel_global.setArg(2, size);
90 kernel_global.setArg(3,
uint(local_size * 2));
92 cl::NDRange step_final_global(acc->get_max_wgs());
93 cl::NDRange step_final_local(acc->get_max_wgs());
94 queue.enqueueNDRangeKernel(kernel_global, cl::NDRange(), step_final_global, step_final_local);
104 const uint BITS_COUNT = 4;
105 const uint BITS_VALS = 1 << BITS_COUNT;
106 const uint BITS_MASK = BITS_VALS - 1;
108 auto* cl_acc = get_acc_cl();
109 const uint block_size = cl_acc->get_default_wgs();
116 .
add_type(
"TYPE", get_ttype<T>().
template as<Type>())
120 const uint n_treads_total = align(n, block_size);
121 const uint n_groups = div_up(n, block_size);
122 const uint n_blocks_sizes = n_groups * BITS_VALS;
124 cl::Buffer cl_temp_keys;
125 cl::Buffer cl_temp_values;
126 cl_acc->get_alloc_general()->alloc_paired(
sizeof(
uint) * n,
sizeof(T) * n, cl_temp_keys, cl_temp_values);
128 cl::Buffer cl_offsets = tmp_alloc->
alloc(
sizeof(
uint) * n);
129 cl::Buffer cl_blocks_size = tmp_alloc->
alloc(
sizeof(
uint) * n_blocks_sizes);
131 auto kernel_local = builder.
make_kernel(
"radix_sort_local");
132 auto kernel_scatter = builder.
make_kernel(
"radix_sort_scatter");
134 const uint bits_in_max_key =
static_cast<uint>(std::floor(std::log2(
float(max_key)))) + 1;
135 const uint bits_aligned = align(bits_in_max_key, BITS_COUNT);
138 cl::Buffer in_keys = keys;
139 cl::Buffer in_values = values;
140 cl::Buffer out_keys = cl_temp_keys;
141 cl::Buffer out_values = cl_temp_values;
143 for (
uint shift = 0; shift <= max_bits - BITS_COUNT; shift += BITS_COUNT) {
144 cl::NDRange global(n_treads_total);
145 cl::NDRange local(block_size);
147 kernel_local.setArg(0, in_keys);
148 kernel_local.setArg(1, cl_offsets);
149 kernel_local.setArg(2, cl_blocks_size);
150 kernel_local.setArg(3, n);
151 kernel_local.setArg(4, shift);
152 queue.enqueueNDRangeKernel(kernel_local, cl::NDRange(), global, local);
156 kernel_scatter.setArg(0, in_keys);
157 kernel_scatter.setArg(1, in_values);
158 kernel_scatter.setArg(2, out_keys);
159 kernel_scatter.setArg(3, out_values);
160 kernel_scatter.setArg(4, cl_offsets);
161 kernel_scatter.setArg(5, cl_blocks_size);
162 kernel_scatter.setArg(6, n);
163 kernel_scatter.setArg(7, shift);
164 queue.enqueueNDRangeKernel(kernel_scatter, cl::NDRange(), global, local);
166 std::swap(in_keys, out_keys);
167 std::swap(in_values, out_values);
181 const uint sort_switch = 2u << 14u;
183 if (n <= sort_switch) {
184 cl_sort_by_key_bitonic<T>(queue, keys, values, n);
186 cl_sort_by_key_radix<T>(queue, keys, values, n, tmp_alloc, max_key);
Base class for any device-local OpenCL buffer allocator.
Definition: cl_alloc.hpp:39
virtual cl::Buffer alloc(std::size_t size)=0
Runtime OpenCL program builder.
Definition: cl_program_builder.hpp:55
CLProgramBuilder & add_define(const char *define, int value)
Definition: cl_program_builder.cpp:41
CLProgramBuilder & set_name(const char *name)
Definition: cl_program_builder.cpp:37
CLProgramBuilder & add_type(const char *alias, const ref_ptr< Type > &type)
Definition: cl_program_builder.cpp:45
cl::Kernel make_kernel(const char *name)
Definition: cl_program_builder.hpp:67
CLProgramBuilder & set_source(const char *source)
Definition: cl_program_builder.cpp:61
void acquire()
Definition: cl_program_builder.cpp:65
ref_ptr< OpBinary > PLUS_UINT
Definition: op.cpp:76
std::uint32_t uint
Library index and size type.
Definition: config.hpp:56
#define LOG_MSG(status, msg)
Definition: logger.hpp:66
Definition: algorithm.hpp:37
void cl_sort_by_key(cl::CommandQueue &queue, cl::Buffer &keys, cl::Buffer &values, uint n, CLAlloc *tmp_alloc, uint max_key=0xffffffff)
Definition: cl_sort_by_key.hpp:175
void cl_sort_by_key_radix(cl::CommandQueue &queue, cl::Buffer &keys, cl::Buffer &values, uint n, CLAlloc *tmp_alloc, uint max_key=0xffffffff)
Definition: cl_sort_by_key.hpp:98
T min(T a, T b)
Definition: op.cpp:152
void cl_sort_by_key_bitonic(cl::CommandQueue &queue, cl::Buffer &keys, cl::Buffer &values, uint size)
Definition: cl_sort_by_key.hpp:44