spla
auto_mxv.hpp
Go to the documentation of this file.
1 // Copyright (c) 2021 - 2023 SparseLinearAlgebra
3 // Autogenerated file, do not modify
5 
6 #pragma once
7 
8 static const char source_mxv[] = R"(
9 
10 
// One halving step of an in-place tree reduction over one row of local memory.
//
// block_size : number of partial values combined by this step; work-items with
//              lid < block_size / 2 fold s_sum[lid + block_size / 2] into
//              s_sum[lid] using OP_BINARY2.
// lid        : caller's index within the row.
// s_sum      : the row of BLOCK_SIZE partial values being reduced.
//
// The step is a no-op when the compile-time BLOCK_SIZE is smaller than
// block_size, so callers can issue a fixed descending sequence of steps and
// only the relevant ones take effect. Assumes BLOCK_SIZE is a power of two
// (TODO confirm on the host side).
//
// Every work-item of the work-group must call this with the same block_size:
// whenever the step is not a no-op it ends in a work-group barrier.
void reduction_group(uint block_size,
                     uint lid,
                     volatile __local TYPE* s_sum) {
    if (BLOCK_SIZE >= block_size) {
        const uint offset = block_size / 2;

        if (lid < offset) {
            s_sum[lid] = OP_BINARY2(s_sum[lid], s_sum[lid + offset]);
        }

        // Synchronize after every step, not only when block_size > WARP_SIZE.
        // OpenCL gives no lock-step ("warp-synchronous") execution guarantee.
        // In mxv_vector the work-items of one row also differ in local id
        // dimension 1, so they are not adjacent in the linear local id and may
        // not share a hardware warp at all. Without the barrier, a step could
        // read partial sums that the previous step had not written yet.
        barrier(CLK_LOCAL_MEM_FENCE);
    }
}
23 
// Sparse matrix-vector product that assigns a group of work-items to each row.
//
// Local work-item layout (assumed; TODO confirm against the host launch code):
//   dimension 0: row slot inside the work-group (lgroup), expected < BLOCK_COUNT
//   dimension 1: work-item inside that row (lid), expected size == BLOCK_SIZE
//
// For every visited row r (r = gid, gid + gstride, ...) with r < n:
//   g_rx[r] = init                                    if !OP_SELECT(g_mask[r])
//   g_rx[r] = OP_BINARY2-fold of init with            otherwise
//             OP_BINARY1(g_Ax[i], g_vx[g_Aj[i]])
//             for i in [g_Ap[r], g_Ap[r + 1])
// Each lid accumulates a strided slice of the row, starting from init. The
// partial values are then combined in local memory by reduction_group.
// Reduction steps are issued for BLOCK_SIZE up to 256 (power of two assumed).
__kernel void mxv_vector(__global const uint* g_Ap,
                         __global const uint* g_Aj,
                         __global const TYPE* g_Ax,
                         __global const TYPE* g_vx,
                         __global const TYPE* g_mask,
                         __global TYPE* g_rx,
                         const TYPE init,
                         const uint n) {
    const uint lid = get_local_id(1);        // work-item index inside the row
    const uint lsize = get_local_size(1);    // work-items sharing one row
    const uint lgroup = get_local_id(0);     // row slot inside the work-group
    const uint gid = get_global_id(0);       // first row handled by this slot
    const uint gstride = get_global_size(0); // step between visited rows

    __local TYPE s_sum[BLOCK_COUNT][BLOCK_SIZE];

    // gid - lgroup is identical for every work-item of the work-group, so all
    // of them run the same number of iterations and reach every barrier below.
    // Looping over row_id directly, with the barriers nested under the mask
    // test, would let row slots of one work-group diverge at a barrier. That
    // is undefined behavior in OpenCL.
    for (uint group_base = gid - lgroup; group_base < n; group_base += gstride) {
        const uint row_id = group_base + lgroup;
        const bool in_range = row_id < n;
        const bool selected = in_range && (OP_SELECT(g_mask[row_id]));

        TYPE sum = init;

        if (selected) {
            const uint row_begin = g_Ap[row_id];
            const uint row_end = g_Ap[row_id + 1];

            for (uint i = row_begin + lid; i < row_end; i += lsize) {
                const uint col_id = g_Aj[i];
                sum = OP_BINARY2(sum, OP_BINARY1(g_Ax[i], g_vx[col_id]));
            }
        }

        // Out-of-range or masked-out slots still take part in the reduction,
        // on their own row of s_sum, so the barriers stay uniform. Their
        // result is discarded below.
        s_sum[lgroup][lid] = sum;
        barrier(CLK_LOCAL_MEM_FENCE);

        reduction_group(256, lid, s_sum[lgroup]);
        reduction_group(128, lid, s_sum[lgroup]);
        reduction_group(64, lid, s_sum[lgroup]);
        reduction_group(32, lid, s_sum[lgroup]);
        reduction_group(16, lid, s_sum[lgroup]);
        reduction_group(8, lid, s_sum[lgroup]);
        reduction_group(4, lid, s_sum[lgroup]);
        reduction_group(2, lid, s_sum[lgroup]);

        if (in_range && lid == 0) {
            g_rx[row_id] = selected ? s_sum[lgroup][0] : init;
        }

        // s_sum is overwritten at the start of the next iteration. Wait until
        // lid 0 has read the reduced value before any work-item moves on.
        barrier(CLK_LOCAL_MEM_FENCE);
    }
}
72 
// Sparse matrix-vector product with one work-item per row.
//
// Each work-item visits rows first_row, first_row + row_step, ... below n and
// stores into g_rx[row]:
//   init                                            if the row is masked out
//   OP_BINARY2-fold of init with                    otherwise
//   OP_BINARY1(g_Ax[k], g_vx[g_Aj[k]])
//   for k in [g_Ap[row], g_Ap[row + 1])
// When early_exit is non-zero, the fold for a row stops at the first point
// where the accumulator no longer equals init.
__kernel void mxv_scalar(__global const uint* g_Ap,
                         __global const uint* g_Aj,
                         __global const TYPE* g_Ax,
                         __global const TYPE* g_vx,
                         __global const TYPE* g_mask,
                         __global TYPE* g_rx,
                         const TYPE init,
                         const uint n,
                         const uint early_exit) {
    const uint first_row = get_global_id(0);
    const uint row_step = get_global_size(0);

    for (uint row = first_row; row < n; row += row_step) {
        TYPE acc = init;

        if (OP_SELECT(g_mask[row])) {
            const uint row_end = g_Ap[row + 1];
            uint k = g_Ap[row];

            while (k < row_end) {
                acc = OP_BINARY2(acc, OP_BINARY1(g_Ax[k], g_vx[g_Aj[k]]));
                if (early_exit && acc != init) {
                    break;
                }
                ++k;
            }
        }

        g_rx[row] = acc;
    }
}
103 
// Resets g_rx[0..n) to init and appends the index of every row selected by
// OP_SELECT(g_mask[row]) to g_config.
//
// The write slot is reserved with atomic_inc on *g_config_size, so the counter
// grows by one per selected row. The order of indices in g_config follows
// atomic scheduling and is not guaranteed to be sorted. The counter is
// presumably zeroed by the host before launch (TODO confirm).
__kernel void mxv_config(__global const TYPE* g_mask,
                         __global TYPE* g_rx,
                         __global uint* g_config,
                         __global uint* g_config_size,
                         const TYPE init,
                         const uint n) {
    const uint step = get_global_size(0);

    for (uint row = get_global_id(0); row < n; row += step) {
        g_rx[row] = init;

        if (!(OP_SELECT(g_mask[row]))) {
            continue;
        }

        g_config[atomic_inc(g_config_size)] = row;
    }
}
122 
// One-work-item-per-row product restricted to a precomputed list of rows.
//
// g_config[0..n) holds the row indices to process (presumably the compacted
// list written by mxv_config; TODO confirm). For each listed row, g_rx[row]
// receives the OP_BINARY2-fold of init with OP_BINARY1(g_Ax[k], g_vx[g_Aj[k]])
// over k in [g_Ap[row], g_Ap[row + 1]). Rows not in the list are not written
// here. When early_exit is non-zero, the fold stops at the first point where
// the accumulator no longer equals init.
__kernel void mxv_config_scalar(__global const uint* g_Ap,
                                __global const uint* g_Aj,
                                __global const TYPE* g_Ax,
                                __global const TYPE* g_vx,
                                __global const uint* g_config,
                                __global TYPE* g_rx,
                                const TYPE init,
                                const uint n,
                                const uint early_exit) {
    const uint step = get_global_size(0);

    for (uint slot = get_global_id(0); slot < n; slot += step) {
        const uint row = g_config[slot];
        const uint row_begin = g_Ap[row];
        const uint row_end = g_Ap[row + 1];

        TYPE acc = init;
        uint k = row_begin;

        while (k < row_end) {
            acc = OP_BINARY2(acc, OP_BINARY1(g_Ax[k], g_vx[g_Aj[k]]));
            if (early_exit && acc != init) {
                break;
            }
            ++k;
        }

        g_rx[row] = acc;
    }
}
152 )";