spla
Loading...
Searching...
No Matches
auto_mxv.hpp
Go to the documentation of this file.
1
2// Copyright (c) 2021 - 2023 SparseLinearAlgebra
3// Autogenerated file, do not modify
5
6#pragma once
7
8static const char source_mxv[] = R"(
9
10
// One step of a local-memory tree reduction over s_sum.
//
// For the step of width `block_size`, work-items with lid < block_size / 2
// fold s_sum[lid + block_size / 2] into s_sum[lid] with OP_BINARY2. The step
// is skipped when BLOCK_SIZE < block_size; BLOCK_SIZE is a compile-time
// constant, so the skip is uniform across the work-group.
//
// A barrier now follows every executed step. The previous version dropped the
// barrier once block_size <= WARP_SIZE, relying on warp-synchronous (lockstep)
// execution. OpenCL does not guarantee lockstep execution, and local-memory
// writes by other work-items are only guaranteed visible after a barrier.
// Moreover, mxv_vector indexes rows by local id 0 and lanes by local id 1, so
// the lanes of one row are not contiguous in linear local id and need not
// share a hardware warp at all.
//
// Every work-item of the work-group must call this with the same block_size,
// because it contains a barrier.
void reduction_group(uint block_size,
                     uint lid,
                     volatile __local TYPE* s_sum) {
    if (BLOCK_SIZE >= block_size) {
        const uint stride = block_size / 2;
        if (lid < stride) {
            s_sum[lid] = OP_BINARY2(s_sum[lid], s_sum[lid + stride]);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
}
23
// CSR matrix-vector product with several lanes cooperating on each row.
// Local id 0 selects a row slot inside the work-group; local id 1 selects a
// lane within that row.
//
// For each row r < n with OP_SELECT(g_mask[r]) true, g_rx[r] receives the
// OP_BINARY2 reduction of OP_BINARY1(g_Ax[i], g_vx[g_Aj[i]]) over
// i in [g_Ap[r], g_Ap[r + 1]). Each lane's partial accumulator is seeded with
// init. Rows that are not selected receive g_rx[r] = init.
//
// NOTE(review): assumes the host launches with get_local_size(0) == BLOCK_COUNT,
// get_local_size(1) == BLOCK_SIZE, and BLOCK_SIZE a power of two — confirm.
//
// Barrier correctness: every work-item of a work-group must reach the same
// barriers. The loop therefore walks the work-group's base row, which is
// uniform across the group, instead of each work-item's own row. Masked-out
// or out-of-range row slots still take part in the reduction, contributing init.
__kernel void mxv_vector(__global const uint* g_Ap,
                         __global const uint* g_Aj,
                         __global const TYPE* g_Ax,
                         __global const TYPE* g_vx,
                         __global const TYPE* g_mask,
                         __global TYPE* g_rx,
                         const TYPE init,
                         const uint n) {
    const uint lid = get_local_id(1);        // lane within a row
    const uint lsize = get_local_size(1);    // lanes per row
    const uint lgroup = get_local_id(0);     // row slot within the work-group
    const uint gid = get_global_id(0);       // first row handled by this slot
    const uint gstride = get_global_size(0); // step between row ids

    __local TYPE s_sum[BLOCK_COUNT][BLOCK_SIZE];

    // gid - lgroup is identical for every work-item of the group, so the trip
    // count is uniform and the barriers below are never divergent. The set of
    // rows visited is the same as iterating row_id = gid; row_id < n directly.
    for (uint row_base = gid - lgroup; row_base < n; row_base += gstride) {
        const uint row_id = row_base + lgroup;
        const bool in_range = row_id < n;
        // && short-circuits, so g_mask is never read at or past n.
        const bool selected = in_range && (OP_SELECT(g_mask[row_id]));

        TYPE sum = init;
        if (selected) {
            const uint start = g_Ap[row_id];
            const uint end = g_Ap[row_id + 1];

            for (uint i = start + lid; i < end; i += lsize) {
                const uint col_id = g_Aj[i];
                sum = OP_BINARY2(sum, OP_BINARY1(g_Ax[i], g_vx[col_id]));
            }
        }

        s_sum[lgroup][lid] = sum;
        barrier(CLK_LOCAL_MEM_FENCE);

        // Reduce all BLOCK_SIZE lanes. Starting at BLOCK_SIZE rather than a
        // fixed 64 keeps the partial sums of lanes >= 64 when BLOCK_SIZE is larger.
        for (uint width = BLOCK_SIZE; width > 1; width >>= 1) {
            reduction_group(width, lid, s_sum[lgroup]);
        }

        if (lid == 0 && in_range) {
            g_rx[row_id] = selected ? s_sum[lgroup][0] : init;
        }

        // s_sum is rewritten in the next iteration; make sure lane 0 has
        // finished reading the reduced value first.
        barrier(CLK_LOCAL_MEM_FENCE);
    }
}
72
// CSR matrix-vector product with one work-item per row. Rows are visited
// starting at get_global_id(0) in steps of get_global_size(0).
//
// For a row r with OP_SELECT(g_mask[r]) true, the accumulator starts at init
// and absorbs OP_BINARY1(g_Ax[k], g_vx[g_Aj[k]]) via OP_BINARY2 for each
// k in [g_Ap[r], g_Ap[r + 1]). Unselected rows keep the value init.
// When early_exit is non-zero, a row's fold stops at the first element after
// which the accumulator differs from init. The accumulator is always
// stored to g_rx[r].
__kernel void mxv_scalar(__global const uint* g_Ap,
                         __global const uint* g_Aj,
                         __global const TYPE* g_Ax,
                         __global const TYPE* g_vx,
                         __global const TYPE* g_mask,
                         __global TYPE* g_rx,
                         const TYPE init,
                         const uint n,
                         const uint early_exit) {
    const uint row_step = get_global_size(0);

    for (uint row = get_global_id(0); row < n; row += row_step) {
        TYPE acc = init;

        if (OP_SELECT(g_mask[row])) {
            const uint row_begin = g_Ap[row];
            const uint row_end = g_Ap[row + 1];
            uint k = row_begin;

            while (k < row_end) {
                acc = OP_BINARY2(acc, OP_BINARY1(g_Ax[k], g_vx[g_Aj[k]]));
                if (early_exit && (acc != init)) {
                    break;
                }
                ++k;
            }
        }

        g_rx[row] = acc;
    }
}
103
// Prepares the compacted-row path.
// - Every g_rx[i] with i < n is set to init.
// - Each index i whose mask passes OP_SELECT is appended to g_config.
// - Slots in g_config are reserved with atomic_inc on *g_config_size, so the
//   order of the collected indices depends on scheduling.
// NOTE(review): *g_config_size is presumably zeroed by the host before
// launch — confirm.
__kernel void mxv_config(__global const TYPE* g_mask,
                         __global TYPE* g_rx,
                         __global uint* g_config,
                         __global uint* g_config_size,
                         const TYPE init,
                         const uint n) {
    const uint step = get_global_size(0);

    for (uint idx = get_global_id(0); idx < n; idx += step) {
        g_rx[idx] = init;

        // Outer parentheses keep the negation correct whatever OP_SELECT expands to.
        if (!(OP_SELECT(g_mask[idx]))) {
            continue;
        }

        g_config[atomic_inc(g_config_size)] = idx;
    }
}
122
// One work-item per entry of the row list g_config[0 .. n). The list is
// presumably the one built by mxv_config, in which case n is the number of
// collected rows — confirm on the host side.
//
// For each listed row r, the accumulator starts at init and absorbs
// OP_BINARY1(g_Ax[k], g_vx[g_Aj[k]]) via OP_BINARY2 for each
// k in [g_Ap[r], g_Ap[r + 1]). The result is stored to g_rx[r]. When
// early_exit is non-zero, the fold stops at the first element after which
// the accumulator differs from init.
__kernel void mxv_config_scalar(__global const uint* g_Ap,
                                __global const uint* g_Aj,
                                __global const TYPE* g_Ax,
                                __global const TYPE* g_vx,
                                __global const uint* g_config,
                                __global TYPE* g_rx,
                                const TYPE init,
                                const uint n,
                                const uint early_exit) {
    const uint slot_step = get_global_size(0);

    for (uint slot = get_global_id(0); slot < n; slot += slot_step) {
        const uint row = g_config[slot];
        const uint row_begin = g_Ap[row];
        const uint row_end = g_Ap[row + 1];

        TYPE acc = init;
        for (uint k = row_begin; k < row_end; ++k) {
            acc = OP_BINARY2(acc, OP_BINARY1(g_Ax[k], g_vx[g_Aj[k]]));
            if (early_exit && (acc != init)) {
                break;
            }
        }

        g_rx[row] = acc;
    }
}
152)";