spla
auto_mxmT_masked.hpp
Go to the documentation of this file.
1 // Copyright (c) 2021 - 2023 SparseLinearAlgebra
3 // Autogenerated file, do not modify
5 
6 #pragma once
7 
// OpenCL C source text of the `mxmT_masked_csr_scalar` kernel, embedded as a
// raw string literal.
//
// TYPE, OP_SELECT, OP_BINARY1 and OP_BINARY2 are used but not defined inside
// this string; presumably the host supplies them (e.g. as prepended #defines
// or build options) before the program is built -- TODO confirm against the
// code that consumes this array.
//
// Kernel arguments, as used by the kernel body:
//   g_Ap, g_Aj, g_Ax          row-offset / index / value arrays of operand A
//   g_Bp, g_Bj, g_Bx          row-offset / index / value arrays of operand B
//   g_maskp, g_maskj, g_maskx row-offset / index / value arrays of the mask
//   g_Rx                      output values, written at the same positions
//                             `mask_k` that index g_maskj / g_maskx
//   init                      starting value of each accumulated result
//   n                         exclusive upper bound on the row ids processed
// (The names suggest a CSR layout; the kernel only relies on
//  p[row] .. p[row + 1] delimiting a row's entries in the j / x arrays.)
//
// What the kernel text does:
//   * Row ids start at get_global_id(0) and advance by get_global_size(0)
//     while row_id < n.
//   * Within a row, mask entries in [g_maskp[row_id], g_maskp[row_id + 1])
//     are visited starting at get_local_id(1) with step get_local_size(1).
//   * For a mask entry with index mask_j and value mask_x, r starts as
//     `init`. Only if OP_SELECT(mask_x) holds, row `row_id` of A and row
//     `mask_j` of B are walked together with two cursors; whenever the two
//     current indices are equal, r = OP_BINARY2(r, OP_BINARY1(a, b)) and both
//     cursors advance, otherwise the cursor with the smaller index advances.
//   * r is stored to g_Rx[mask_k] unconditionally, so entries rejected by
//     OP_SELECT receive `init`.
//
// NOTE(review): the two-cursor walk finds every common index only if the
// indices inside each row of g_Aj and g_Bj are in ascending order -- this
// ordering is assumed here, not checked; confirm with the producer of A and B.
// NOTE(review): the file header marks this file as autogenerated; mirror any
// documentation change in the generator's input so it is not lost.
8 static const char source_mxmT_masked[] = R"(
9 
10 
11 __kernel void mxmT_masked_csr_scalar(__global const uint* g_Ap,
12  __global const uint* g_Aj,
13  __global const TYPE* g_Ax,
14  __global const uint* g_Bp,
15  __global const uint* g_Bj,
16  __global const TYPE* g_Bx,
17  __global const uint* g_maskp,
18  __global const uint* g_maskj,
19  __global const TYPE* g_maskx,
20  __global TYPE* g_Rx,
21  const TYPE init,
22  const uint n) {
23  const uint lid = get_local_id(1); // thread id in a row
24  const uint lsize = get_local_size(1); // size of local group
25  const uint gid = get_global_id(0); // id of row to touch
26  const uint gstride = get_global_size(0);// step between row ids
27 
28  for (uint row_id = gid; row_id < n; row_id += gstride) {
29  const uint mask_start = g_maskp[row_id];
30  const uint mask_end = g_maskp[row_id + 1];
31 
32  const uint A_start = g_Ap[row_id];
33  const uint A_end = g_Ap[row_id + 1];
34 
35  for (uint mask_k = mask_start + lid; mask_k < mask_end; mask_k += lsize) {
36  const uint mask_j = g_maskj[mask_k];
37  const TYPE mask_x = g_maskx[mask_k];
38 
39  TYPE r = init;
40 
41  if (OP_SELECT(mask_x)) {
42  const uint B_start = g_Bp[mask_j];
43  const uint B_end = g_Bp[mask_j + 1];
44 
45  uint A_it = A_start;
46  uint B_it = B_start;
47 
48  while (A_it < A_end && B_it < B_end) {
49  const uint A_j = g_Aj[A_it];
50  const uint B_j = g_Bj[B_it];
51 
52  if (A_j == B_j) {
53  r = OP_BINARY2(r, OP_BINARY1(g_Ax[A_it], g_Bx[B_it]));
54  ++A_it;
55  ++B_it;
56  } else if (A_j < B_j) {
57  ++A_it;
58  } else {
59  ++B_it;
60  }
61  }
62  }
63 
64  g_Rx[mask_k] = r;
65  }
66  }
67 }
68 )";