spla
Loading...
Searching...
No Matches
cl_sort_by_key.hpp
Go to the documentation of this file.
1/**********************************************************************************/
2/* This file is part of spla project */
3/* https://github.com/JetBrains-Research/spla */
4/**********************************************************************************/
5/* MIT License */
6/* */
7/* Copyright (c) 2023 SparseLinearAlgebra */
8/* */
9/* Permission is hereby granted, free of charge, to any person obtaining a copy */
10/* of this software and associated documentation files (the "Software"), to deal */
11/* in the Software without restriction, including without limitation the rights */
12/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
13/* copies of the Software, and to permit persons to whom the Software is */
14/* furnished to do so, subject to the following conditions: */
15/* */
16/* The above copyright notice and this permission notice shall be included in all */
17/* copies or substantial portions of the Software. */
18/* */
19/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
20/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
21/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
22/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
23/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
24/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
25/* SOFTWARE. */
26/**********************************************************************************/
27
28#ifndef SPLA_CL_SORT_BY_KEY_HPP
29#define SPLA_CL_SORT_BY_KEY_HPP
30
32#include <opencl/cl_alloc.hpp>
38
39#include <cmath>
40
41namespace spla {
42
43 template<typename T>
44 void cl_sort_by_key_bitonic(cl::CommandQueue& queue, cl::Buffer& keys, cl::Buffer& values, uint size) {
45 if (size <= 1) {
46 LOG_MSG(Status::Ok, "nothing to do");
47 return;
48 }
49
50 auto* acc = get_acc_cl();
51
52 const uint pair_size = sizeof(uint) + sizeof(T);
53 const uint local_size = floor_to_pow2(acc->get_max_local_mem() / pair_size);
54 const uint max_treads_per_block = std::min(acc->get_max_wgs(), local_size / 2u);
55
56 assert(local_size > 2);
57
58 CLProgramBuilder builder;
59 builder.set_name("sort_bitonic")
60 .add_type("TYPE", get_ttype<T>().template as<Type>())
61 .add_define("BLOCK_SIZE", local_size)
62 .set_source(source_sort_bitonic)
63 .acquire();
64
65 auto kernel_local = builder.make_kernel("bitonic_sort_local");
66 kernel_local.setArg(0, keys);
67 kernel_local.setArg(1, values);
68 kernel_local.setArg(2, size);
69
70 if (size <= local_size) {
71 const uint wave_size = acc->get_wave_size();
72 const uint n_threads = align(std::min(size, max_treads_per_block), wave_size);
73
74 cl::NDRange global(n_threads);
75 cl::NDRange local(n_threads);
76 queue.enqueueNDRangeKernel(kernel_local, cl::NDRange(), global, local);
77 return;
78 }
79
80 const uint n_groups = div_up(size, local_size);
81
82 cl::NDRange step_pre_sort_global(max_treads_per_block * n_groups);
83 cl::NDRange step_pre_sort_local(max_treads_per_block);
84 queue.enqueueNDRangeKernel(kernel_local, cl::NDRange(), step_pre_sort_global, step_pre_sort_local);
85
86 auto kernel_global = builder.make_kernel("bitonic_sort_global");
87 kernel_global.setArg(0, keys);
88 kernel_global.setArg(1, values);
89 kernel_global.setArg(2, size);
90 kernel_global.setArg(3, uint(local_size * 2));
91
92 cl::NDRange step_final_global(acc->get_max_wgs());
93 cl::NDRange step_final_local(acc->get_max_wgs());
94 queue.enqueueNDRangeKernel(kernel_global, cl::NDRange(), step_final_global, step_final_local);
95 }
96
97 template<typename T>
98 void cl_sort_by_key_radix(cl::CommandQueue& queue, cl::Buffer& keys, cl::Buffer& values, uint n, CLAlloc* tmp_alloc, uint max_key = 0xffffffff) {
99 if (n <= 1) {
100 LOG_MSG(Status::Ok, "nothing to do");
101 return;
102 }
103
104 const uint BITS_COUNT = 4;
105 const uint BITS_VALS = 1 << BITS_COUNT;
106 const uint BITS_MASK = BITS_VALS - 1;
107
108 auto* cl_acc = get_acc_cl();
109 const uint block_size = cl_acc->get_default_wgs();
110
111 CLProgramBuilder builder;
112 builder.set_name("radix_sort")
113 .add_define("BLOCK_SIZE", block_size)
114 .add_define("BITS_VALS", BITS_VALS)
115 .add_define("BITS_MASK", BITS_MASK)
116 .add_type("TYPE", get_ttype<T>().template as<Type>())
117 .set_source(source_sort_radix)
118 .acquire();
119
120 const uint n_treads_total = align(n, block_size);
121 const uint n_groups = div_up(n, block_size);
122 const uint n_blocks_sizes = n_groups * BITS_VALS;
123
124 cl::Buffer cl_temp_keys;
125 cl::Buffer cl_temp_values;
126 cl_acc->get_alloc_general()->alloc_paired(sizeof(uint) * n, sizeof(T) * n, cl_temp_keys, cl_temp_values);
127
128 cl::Buffer cl_offsets = tmp_alloc->alloc(sizeof(uint) * n);
129 cl::Buffer cl_blocks_size = tmp_alloc->alloc(sizeof(uint) * n_blocks_sizes);
130
131 auto kernel_local = builder.make_kernel("radix_sort_local");
132 auto kernel_scatter = builder.make_kernel("radix_sort_scatter");
133
134 const uint bits_in_max_key = static_cast<uint>(std::floor(std::log2(float(max_key)))) + 1;
135 const uint bits_aligned = align(bits_in_max_key, BITS_COUNT);
136 const uint max_bits = std::min(32u, bits_aligned);
137
138 cl::Buffer in_keys = keys;
139 cl::Buffer in_values = values;
140 cl::Buffer out_keys = cl_temp_keys;
141 cl::Buffer out_values = cl_temp_values;
142
143 for (uint shift = 0; shift <= max_bits - BITS_COUNT; shift += BITS_COUNT) {
144 cl::NDRange global(n_treads_total);
145 cl::NDRange local(block_size);
146
147 kernel_local.setArg(0, in_keys);
148 kernel_local.setArg(1, cl_offsets);
149 kernel_local.setArg(2, cl_blocks_size);
150 kernel_local.setArg(3, n);
151 kernel_local.setArg(4, shift);
152 queue.enqueueNDRangeKernel(kernel_local, cl::NDRange(), global, local);
153
154 cl_exclusive_scan<uint>(queue, cl_blocks_size, n_blocks_sizes, PLUS_UINT.template cast_safe<TOpBinary<uint, uint, uint>>(), tmp_alloc);
155
156 kernel_scatter.setArg(0, in_keys);
157 kernel_scatter.setArg(1, in_values);
158 kernel_scatter.setArg(2, out_keys);
159 kernel_scatter.setArg(3, out_values);
160 kernel_scatter.setArg(4, cl_offsets);
161 kernel_scatter.setArg(5, cl_blocks_size);
162 kernel_scatter.setArg(6, n);
163 kernel_scatter.setArg(7, shift);
164 queue.enqueueNDRangeKernel(kernel_scatter, cl::NDRange(), global, local);
165
166 std::swap(in_keys, out_keys);
167 std::swap(in_values, out_values);
168 }
169
170 keys = in_keys;
171 values = in_values;
172 }
173
174 template<typename T>
175 void cl_sort_by_key(cl::CommandQueue& queue, cl::Buffer& keys, cl::Buffer& values, uint n, CLAlloc* tmp_alloc, uint max_key = 0xffffffff) {
176 if (n <= 1) {
177 LOG_MSG(Status::Ok, "nothing to do");
178 return;
179 }
180
181 const uint sort_switch = 2u << 14u;
182
183 if (n <= sort_switch) {
184 cl_sort_by_key_bitonic<T>(queue, keys, values, n);
185 } else {
186 cl_sort_by_key_radix<T>(queue, keys, values, n, tmp_alloc, max_key);
187 }
188 }
189
190}// namespace spla
191
192#endif//SPLA_CL_SORT_BY_KEY_HPP
Base class for any device-local opencl buffer allocator.
Definition cl_alloc.hpp:39
virtual cl::Buffer alloc(std::size_t size)=0
Runtime opencl program builder.
Definition cl_program_builder.hpp:55
CLProgramBuilder & add_define(const char *define, int value)
Definition cl_program_builder.cpp:41
CLProgramBuilder & set_name(const char *name)
Definition cl_program_builder.cpp:37
CLProgramBuilder & add_type(const char *alias, const ref_ptr< Type > &type)
Definition cl_program_builder.cpp:45
cl::Kernel make_kernel(const char *name)
Definition cl_program_builder.hpp:67
CLProgramBuilder & set_source(const char *source)
Definition cl_program_builder.cpp:61
void acquire()
Definition cl_program_builder.cpp:65
Definition top.hpp:174
ref_ptr< OpBinary > PLUS_UINT
Definition op.cpp:76
std::uint32_t uint
Library index and size type.
Definition config.hpp:56
#define LOG_MSG(status, msg)
Definition logger.hpp:66
Definition algorithm.hpp:37
void cl_sort_by_key(cl::CommandQueue &queue, cl::Buffer &keys, cl::Buffer &values, uint n, CLAlloc *tmp_alloc, uint max_key=0xffffffff)
Definition cl_sort_by_key.hpp:175
void cl_sort_by_key_radix(cl::CommandQueue &queue, cl::Buffer &keys, cl::Buffer &values, uint n, CLAlloc *tmp_alloc, uint max_key=0xffffffff)
Definition cl_sort_by_key.hpp:98
void cl_exclusive_scan(cl::CommandQueue &queue, cl::Buffer &values, uint n, const ref_ptr< TOpBinary< T, T, T > > &op, CLAlloc *tmp_alloc)
Definition cl_prefix_sum.hpp:39
void cl_sort_by_key_bitonic(cl::CommandQueue &queue, cl::Buffer &keys, cl::Buffer &values, uint size)
Definition cl_sort_by_key.hpp:44