spla
Loading...
Searching...
No Matches
auto_sort_radix.hpp
Go to the documentation of this file.
1
2// Copyright (c) 2021 - 2023 SparseLinearAlgebra
3// Autogenerated file, do not modify
5
6#pragma once
7
8static const char source_sort_radix[] = R"(
9
10
#if defined(RADIX_SORT)
    // Count of distinct digit values handled in one pass (1 << bits per pass).
    #define BITS_VALS 4
    // Mask that isolates one digit of the key after shifting ((1 << bits per pass) - 1).
    #define BITS_MASK 0x3
#endif
17
// Work-group-local ranking pass of the radix sort.
//
// For every digit value `slot` in [0, BITS_VALS) the work-group builds an
// inclusive prefix sum over the flags "my digit equals slot" and then:
//   * each in-range work-item whose digit equals `slot` stores its zero-based
//     rank among the matching items of its group into g_offsets[gid];
//   * work-item 0 stores the last scan entry (the number of matching items in
//     the group) into g_blocks_size[slot * ngroups + gpid].
//
// g_keys        - input keys; the digit is (key >> shift) & BITS_MASK
// g_offsets     - output, one rank per input element
// g_blocks_size - output, per-slot / per-group match counts (slot-major layout)
// n             - number of valid keys; work-items with gid >= n contribute nothing
// shift         - bit position of the digit processed by this pass
//
// NOTE(review): s_offsets[BLOCK_SIZE - 1] is read as the group total, so the
// launch is presumably expected to use a local size equal to BLOCK_SIZE - confirm
// against the host code.
__kernel void radix_sort_local(__global const uint* g_keys,
                               __global uint*       g_offsets,
                               __global uint*       g_blocks_size,
                               const uint           n,
                               const uint           shift) {
    __local uint s_mask[BLOCK_SIZE];
    __local uint s_offsets[BLOCK_SIZE];

    const uint lid     = get_local_id(0);
    const uint gid     = get_global_id(0);
    const uint gpid    = get_group_id(0);
    const uint ngroups = get_num_groups(0);

    // Cache this work-item's digit; padding work-items store 0 and are
    // filtered out below by the gid < n test.
    if (gid < n) {
        s_mask[lid] = (g_keys[gid] >> shift) & BITS_MASK;
    } else {
        s_mask[lid] = 0;
    }

    for (uint slot = 0; slot < BITS_VALS; ++slot) {
        // Separates the previous iteration's reads of s_offsets (and the
        // initial s_mask store) from the writes that follow.
        barrier(CLK_LOCAL_MEM_FENCE);

        // 1 when this work-item holds a valid key whose digit is `slot`.
        const uint is_match = (gid < n && s_mask[lid] == slot) ? 1 : 0;
        s_offsets[lid] = is_match;

        // Inclusive scan in local memory: at every step each entry adds the
        // entry `stride` positions to its left, with the stride doubling.
        for (uint stride = 1; stride < BLOCK_SIZE; stride <<= 1) {
            barrier(CLK_LOCAL_MEM_FENCE);
            const uint left = (lid >= stride) ? s_offsets[lid - stride] : 0;
            const uint sum  = s_offsets[lid] + left;

            // Every work-item must finish reading before anyone overwrites.
            barrier(CLK_LOCAL_MEM_FENCE);
            s_offsets[lid] = sum;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        // is_match can only be set when gid < n, so no extra bounds test is
        // needed; the inclusive sum minus one is the zero-based rank.
        if (is_match) {
            g_offsets[gid] = s_offsets[lid] - 1;
        }

        // The last scan entry is the total number of matches in this group.
        if (lid == 0) {
            g_blocks_size[slot * ngroups + gpid] = s_offsets[BLOCK_SIZE - 1];
        }
    }
}
67
// Scatter pass of the radix sort: moves every key/value pair to its
// destination for the digit selected by `shift`.
//
// The destination index is
//   g_blocks_offsets[digit * ngroups + group] + g_offsets[gid]
// where digit = (key >> shift) & BITS_MASK.
//
// g_in_keys / g_in_values   - input pairs
// g_out_keys / g_out_values - output pairs, written at the computed index
// g_offsets                 - per-element offset added to the group base
// g_blocks_offsets          - per-digit / per-group base positions
//                             (presumably a scan of the per-group counts
//                             produced by the host - verify against caller)
// n                         - number of valid elements
// shift                     - bit position of the digit processed by this pass
__kernel void radix_sort_scatter(__global const uint* g_in_keys,
                                 __global const TYPE* g_in_values,
                                 __global uint*       g_out_keys,
                                 __global TYPE*       g_out_values,
                                 __global const uint* g_offsets,
                                 __global const uint* g_blocks_offsets,
                                 const uint           n,
                                 const uint           shift) {
    const uint gid = get_global_id(0);

    // Work-items past the end of the input have nothing to move.
    if (gid >= n) {
        return;
    }

    const uint group_id    = get_group_id(0);
    const uint group_count = get_num_groups(0);

    // Load both halves of the pair before any store is issued.
    const uint key   = g_in_keys[gid];
    const TYPE value = g_in_values[gid];

    const uint digit = (key >> shift) & BITS_MASK;
    const uint base  = g_blocks_offsets[digit * group_count + group_id];
    const uint dst   = base + g_offsets[gid];

    g_out_keys[dst]   = key;
    g_out_values[dst] = value;
}
92)";