spla
auto_sort_radix.hpp
Go to the documentation of this file.
1 // Copyright (c) 2021 - 2023 SparseLinearAlgebra
3 // Autogenerated file, do not modify
5 
6 #pragma once
7 
8 static const char source_sort_radix[] = R"(
9 
10 
#ifdef RADIX_SORT
// Number of distinct digit values handled per sort pass: 1 << (bits per pass),
// i.e. 4 values for 2-bit digits. Must be a power of two.
#define BITS_VALS 4
// Mask that extracts one digit after shifting: (1 << bits per pass) - 1.
// Derived from BITS_VALS so the two constants cannot drift apart
// (evaluates to 0x3 when BITS_VALS is 4).
#define BITS_MASK (BITS_VALS - 1)
#endif
17 
// Local (per work-group) stage of one radix-sort digit pass.
//
// The digit of element i is (g_keys[i] >> shift) & BITS_MASK. For every digit
// value d in [0, BITS_VALS) this kernel:
//   * writes g_offsets[i] = number of elements with digit d that precede i
//     inside the same work-group (a 0-based rank), for every i < n whose
//     digit is d;
//   * has work-item 0 write the group's count of elements with digit d to
//     g_blocks_size[d * num_groups + group]. The counts are laid out
//     digit-major.
// Work-items with a global id >= n are padding. They never contribute to a
// count, and their g_offsets entries are not written.
//
// NOTE(review): the scan and the read of s_offsets[BLOCK_SIZE - 1] assume the
// local work size equals BLOCK_SIZE. BLOCK_SIZE is not defined in this file
// and is presumably passed as a build option; confirm against the host code.
__kernel void radix_sort_local(__global const uint* g_keys,
                               __global uint* g_offsets,
                               __global uint* g_blocks_size,
                               const uint n,
                               const uint shift) {
    const uint global_idx = get_global_id(0);
    const uint group_idx = get_group_id(0);
    const uint num_groups = get_num_groups(0);
    const uint local_idx = get_local_id(0);
    const int in_range = global_idx < n;

    __local uint s_digit[BLOCK_SIZE];   // digit of each work-item's key
    __local uint s_offsets[BLOCK_SIZE]; // scan workspace, reused for every digit

    if (in_range) {
        s_digit[local_idx] = (g_keys[global_idx] >> shift) & BITS_MASK;
    } else {
        s_digit[local_idx] = 0;
    }

    for (uint digit = 0; digit < BITS_VALS; ++digit) {
        // All work-items must finish reading s_offsets for the previous digit
        // (including work-item 0's read of the group total) before it is reset.
        barrier(CLK_LOCAL_MEM_FENCE);

        const uint flag = (in_range && s_digit[local_idx] == digit) ? 1 : 0;
        s_offsets[local_idx] = flag;

        // Inclusive Hillis-Steele scan of the flags across the work-group.
        // Two barriers per step: every work-item reads the old values before
        // any work-item overwrites its own slot.
        for (uint stride = 1; stride < BLOCK_SIZE; stride <<= 1) {
            barrier(CLK_LOCAL_MEM_FENCE);
            const uint partial = (local_idx >= stride)
                                     ? s_offsets[local_idx] + s_offsets[local_idx - stride]
                                     : s_offsets[local_idx];
            barrier(CLK_LOCAL_MEM_FENCE);
            s_offsets[local_idx] = partial;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        // Inclusive count minus one gives the 0-based rank among same-digit
        // elements. flag can only be set for in-range work-items.
        if (flag) {
            g_offsets[global_idx] = s_offsets[local_idx] - 1;
        }

        // The last inclusive-scan entry is the group's count for this digit.
        if (local_idx == 0) {
            g_blocks_size[digit * num_groups + group_idx] = s_offsets[BLOCK_SIZE - 1];
        }
    }
}
67 
// Scatter stage of one radix-sort digit pass.
//
// Each work-item i < n copies (g_in_keys[i], g_in_values[i]) to position
//   g_blocks_offsets[digit * num_groups + group] + g_offsets[i]
// of g_out_keys / g_out_values, where digit = (key >> shift) & BITS_MASK.
// The bucket index uses the same digit-major layout that radix_sort_local
// uses when writing g_blocks_size.
//
// NOTE(review): presumably g_offsets holds the in-group ranks written by
// radix_sort_local, and g_blocks_offsets is an exclusive prefix sum of its
// g_blocks_size computed elsewhere, with the same launch geometry. Confirm
// against the host code.
__kernel void radix_sort_scatter(__global const uint* g_in_keys,
                                 __global const TYPE* g_in_values,
                                 __global uint* g_out_keys,
                                 __global TYPE* g_out_values,
                                 __global const uint* g_offsets,
                                 __global const uint* g_blocks_offsets,
                                 const uint n,
                                 const uint shift) {
    const uint idx = get_global_id(0);
    const uint group_idx = get_group_id(0);
    const uint num_groups = get_num_groups(0);

    // Padding work-item: nothing to move. An early return is safe because
    // this kernel has no barriers.
    if (idx >= n) {
        return;
    }

    const uint key = g_in_keys[idx];
    const TYPE value = g_in_values[idx];

    const uint digit = (key >> shift) & BITS_MASK;
    const uint bucket = digit * num_groups + group_idx;
    const uint dst = g_blocks_offsets[bucket] + g_offsets[idx];

    g_out_keys[dst] = key;
    g_out_values[dst] = value;
}
92 )";