spla
cl_mxv.hpp
Go to the documentation of this file.
1 /**********************************************************************************/
2 /* This file is part of spla project */
3 /* https://github.com/SparseLinearAlgebra/spla */
4 /**********************************************************************************/
5 /* MIT License */
6 /* */
7 /* Copyright (c) 2023 SparseLinearAlgebra */
8 /* */
9 /* Permission is hereby granted, free of charge, to any person obtaining a copy */
10 /* of this software and associated documentation files (the "Software"), to deal */
11 /* in the Software without restriction, including without limitation the rights */
12 /* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
13 /* copies of the Software, and to permit persons to whom the Software is */
14 /* furnished to do so, subject to the following conditions: */
15 /* */
16 /* The above copyright notice and this permission notice shall be included in all */
17 /* copies or substantial portions of the Software. */
18 /* */
19 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
20 /* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
21 /* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
22 /* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
23 /* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
24 /* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
25 /* SOFTWARE. */
26 /**********************************************************************************/
27 
28 #ifndef SPLA_CL_MXV_HPP
29 #define SPLA_CL_MXV_HPP
30 
32 
33 #include <core/dispatcher.hpp>
34 #include <core/registry.hpp>
35 #include <core/tmatrix.hpp>
36 #include <core/top.hpp>
37 #include <core/tscalar.hpp>
38 #include <core/ttype.hpp>
39 #include <core/tvector.hpp>
40 
41 #include <opencl/cl_counter.hpp>
42 #include <opencl/cl_debug.hpp>
43 #include <opencl/cl_formats.hpp>
46 
47 #include <algorithm>
48 #include <sstream>
49 
50 namespace spla {
51 
52  template<typename T>
53  class Algo_mxv_masked_cl final : public RegistryAlgo {
54  public:
55  ~Algo_mxv_masked_cl() override = default;
56 
57  std::string get_name() override {
58  return "mxv_masked";
59  }
60 
61  std::string get_description() override {
62  return "parallel matrix-vector masked product on opencl device";
63  }
64 
65  Status execute(const DispatchContext& ctx) override {
66  auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();
67  auto early_exit = t->get_desc_or_default()->get_early_exit();
68 
69  if (early_exit) {
70  return execute_config_scalar(ctx);
71  } else {
72  return execute_vector(ctx);
73  }
74  }
75 
76  private:
77  Status execute_vector(const DispatchContext& ctx) {
78  TIME_PROFILE_SCOPE("opencl/mxv/vector");
79 
80  auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();
81 
82  ref_ptr<TVector<T>> r = t->r.template cast_safe<TVector<T>>();
83  ref_ptr<TVector<T>> mask = t->mask.template cast_safe<TVector<T>>();
84  ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
85  ref_ptr<TVector<T>> v = t->v.template cast_safe<TVector<T>>();
86  ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
87  ref_ptr<TOpBinary<T, T, T>> op_add = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
88  ref_ptr<TOpSelect<T>> op_select = t->op_select.template cast_safe<TOpSelect<T>>();
89  ref_ptr<TScalar<T>> init = t->init.template cast_safe<TScalar<T>>();
90 
91  r->validate_wd(FormatVector::AccDense);
92  mask->validate_rw(FormatVector::AccDense);
93  M->validate_rw(FormatMatrix::AccCsr);
94  v->validate_rw(FormatVector::AccDense);
95 
96  std::shared_ptr<CLProgram> program;
97  if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError;
98 
99  auto* p_cl_r = r->template get<CLDenseVec<T>>();
100  auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
101  auto* p_cl_M = M->template get<CLCsr<T>>();
102  auto* p_cl_v = v->template get<CLDenseVec<T>>();
103 
104  auto* p_cl_acc = get_acc_cl();
105  auto& queue = p_cl_acc->get_queue_default();
106 
107  auto kernel_vector = program->make_kernel("mxv_vector");
108  kernel_vector.setArg(0, p_cl_M->Ap);
109  kernel_vector.setArg(1, p_cl_M->Aj);
110  kernel_vector.setArg(2, p_cl_M->Ax);
111  kernel_vector.setArg(3, p_cl_v->Ax);
112  kernel_vector.setArg(4, p_cl_mask->Ax);
113  kernel_vector.setArg(5, p_cl_r->Ax);
114  kernel_vector.setArg(6, init->get_value());
115  kernel_vector.setArg(7, r->get_n_rows());
116 
117  uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_count, 1, 512);
118 
119  cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size);
120  cl::NDRange exec_local(m_block_count, m_block_size);
121  CL_DISPATCH_PROFILED("exec", queue, kernel_vector, cl::NDRange(), exec_global, exec_local);
122 
123  return Status::Ok;
124  }
125 
126  Status execute_scalar(const DispatchContext& ctx) {
127  TIME_PROFILE_SCOPE("opencl/mxv/scalar");
128 
129  auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();
130 
131  ref_ptr<TVector<T>> r = t->r.template cast_safe<TVector<T>>();
132  ref_ptr<TVector<T>> mask = t->mask.template cast_safe<TVector<T>>();
133  ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
134  ref_ptr<TVector<T>> v = t->v.template cast_safe<TVector<T>>();
135  ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
136  ref_ptr<TOpBinary<T, T, T>> op_add = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
137  ref_ptr<TOpSelect<T>> op_select = t->op_select.template cast_safe<TOpSelect<T>>();
138  ref_ptr<TScalar<T>> init = t->init.template cast_safe<TScalar<T>>();
139 
140  r->validate_wd(FormatVector::AccDense);
141  mask->validate_rw(FormatVector::AccDense);
142  M->validate_rw(FormatMatrix::AccCsr);
143  v->validate_rw(FormatVector::AccDense);
144 
145  std::shared_ptr<CLProgram> program;
146  if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError;
147 
148  auto* p_cl_r = r->template get<CLDenseVec<T>>();
149  auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
150  auto* p_cl_M = M->template get<CLCsr<T>>();
151  auto* p_cl_v = v->template get<CLDenseVec<T>>();
152  auto early_exit = t->get_desc_or_default()->get_early_exit();
153 
154  auto* p_cl_acc = get_acc_cl();
155  auto& queue = p_cl_acc->get_queue_default();
156 
157  auto kernel_scalar = program->make_kernel("mxv_scalar");
158  kernel_scalar.setArg(0, p_cl_M->Ap);
159  kernel_scalar.setArg(1, p_cl_M->Aj);
160  kernel_scalar.setArg(2, p_cl_M->Ax);
161  kernel_scalar.setArg(3, p_cl_v->Ax);
162  kernel_scalar.setArg(4, p_cl_mask->Ax);
163  kernel_scalar.setArg(5, p_cl_r->Ax);
164  kernel_scalar.setArg(6, init->get_value());
165  kernel_scalar.setArg(7, r->get_n_rows());
166  kernel_scalar.setArg(8, uint(early_exit));
167 
168  uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_size, 1, 512);
169 
170  cl::NDRange exec_global(m_block_size * n_groups_to_dispatch);
171  cl::NDRange exec_local(m_block_size);
172  CL_DISPATCH_PROFILED("exec", queue, kernel_scalar, cl::NDRange(), exec_global, exec_local);
173 
174  return Status::Ok;
175  }
176 
177  Status execute_config_scalar(const DispatchContext& ctx) {
178  TIME_PROFILE_SCOPE("opencl/mxv/config-scalar");
179 
180  auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();
181 
182  ref_ptr<TVector<T>> r = t->r.template cast_safe<TVector<T>>();
183  ref_ptr<TVector<T>> mask = t->mask.template cast_safe<TVector<T>>();
184  ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
185  ref_ptr<TVector<T>> v = t->v.template cast_safe<TVector<T>>();
186  ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
187  ref_ptr<TOpBinary<T, T, T>> op_add = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
188  ref_ptr<TOpSelect<T>> op_select = t->op_select.template cast_safe<TOpSelect<T>>();
189  ref_ptr<TScalar<T>> init = t->init.template cast_safe<TScalar<T>>();
190 
191  r->validate_wd(FormatVector::AccDense);
192  mask->validate_rw(FormatVector::AccDense);
193  M->validate_rw(FormatMatrix::AccCsr);
194  v->validate_rw(FormatVector::AccDense);
195 
196  std::shared_ptr<CLProgram> program;
197  if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError;
198 
199  auto* p_cl_r = r->template get<CLDenseVec<T>>();
200  auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
201  auto* p_cl_M = M->template get<CLCsr<T>>();
202  auto* p_cl_v = v->template get<CLDenseVec<T>>();
203  auto early_exit = t->get_desc_or_default()->get_early_exit();
204 
205  auto* p_cl_acc = get_acc_cl();
206  auto& queue = p_cl_acc->get_queue_default();
207 
208  uint config_size = 0;
209  cl::Buffer cl_config(p_cl_acc->get_context(), CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, sizeof(uint) * M->get_n_rows());
210  cl::Buffer cl_config_size(p_cl_acc->get_context(), CL_MEM_READ_WRITE | CL_MEM_HOST_READ_ONLY | CL_MEM_COPY_HOST_PTR, sizeof(uint), &config_size);
211 
212  auto kernel_config = program->make_kernel("mxv_config");
213  kernel_config.setArg(0, p_cl_mask->Ax);
214  kernel_config.setArg(1, p_cl_r->Ax);
215  kernel_config.setArg(2, cl_config);
216  kernel_config.setArg(3, cl_config_size);
217  kernel_config.setArg(4, init->get_value());
218  kernel_config.setArg(5, M->get_n_rows());
219 
220  uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_size, 1, 1024);
221 
222  cl::NDRange config_global(m_block_size * n_groups_to_dispatch);
223  cl::NDRange config_local(m_block_size);
224  CL_DISPATCH_PROFILED("config", queue, kernel_config, cl::NDRange(), config_global, config_local);
225 
226  CL_READ_PROFILED("config-size", queue, cl_config_size, true, 0, sizeof(config_size), &config_size);
227 
228  auto kernel_config_scalar = program->make_kernel("mxv_config_scalar");
229  kernel_config_scalar.setArg(0, p_cl_M->Ap);
230  kernel_config_scalar.setArg(1, p_cl_M->Aj);
231  kernel_config_scalar.setArg(2, p_cl_M->Ax);
232  kernel_config_scalar.setArg(3, p_cl_v->Ax);
233  kernel_config_scalar.setArg(4, cl_config);
234  kernel_config_scalar.setArg(5, p_cl_r->Ax);
235  kernel_config_scalar.setArg(6, init->get_value());
236  kernel_config_scalar.setArg(7, config_size);
237  kernel_config_scalar.setArg(8, uint(early_exit));
238 
239  n_groups_to_dispatch = div_up_clamp(config_size, m_block_size, 1, 1024);
240 
241  cl::NDRange exec_global(m_block_size * n_groups_to_dispatch);
242  cl::NDRange exec_local(m_block_size);
243  CL_DISPATCH_PROFILED("exec", queue, kernel_config_scalar, cl::NDRange(), exec_global, exec_local);
244 
245  return Status::Ok;
246  }
247 
248  bool ensure_kernel(const ref_ptr<TOpBinary<T, T, T>>& op_multiply,
249  const ref_ptr<TOpBinary<T, T, T>>& op_add,
250  const ref_ptr<TOpSelect<T>>& op_select,
251  std::shared_ptr<CLProgram>& program) {
252  m_block_size = get_acc_cl()->get_wave_size();
253  m_block_count = 1;
254 
255  assert(m_block_count >= 1);
256 
257  CLProgramBuilder program_builder;
258  program_builder
259  .set_name("mxv")
260  .add_define("WARP_SIZE", get_acc_cl()->get_wave_size())
261  .add_define("BLOCK_SIZE", m_block_size)
262  .add_define("BLOCK_COUNT", m_block_count)
263  .add_type("TYPE", get_ttype<T>().template as<Type>())
264  .add_op("OP_BINARY1", op_multiply.template as<OpBinary>())
265  .add_op("OP_BINARY2", op_add.template as<OpBinary>())
266  .add_op("OP_SELECT", op_select.template as<OpSelect>())
267  .set_source(source_mxv)
268  .acquire();
269 
270  program = program_builder.get_program();
271 
272  return true;
273  }
274 
275  private:
276  uint m_block_size = 0;
277  uint m_block_count = 0;
278  };
279 
280 }// namespace spla
281 
282 #endif//SPLA_CL_MXV_HPP
#define CL_DISPATCH_PROFILED(name, queue, kernel,...)
Definition: cl_debug.hpp:36
#define CL_READ_PROFILED(name, queue, buffer,...)
Definition: cl_debug.hpp:53
Status of library operation execution.
Definition: cl_mxv.hpp:53
~Algo_mxv_masked_cl() override=default
std::string get_description() override
Definition: cl_mxv.hpp:61
std::string get_name() override
Definition: cl_mxv.hpp:57
Status execute(const DispatchContext &ctx) override
Definition: cl_mxv.hpp:65
Algorithm suitable to process schedule task based on task string key.
Definition: registry.hpp:66
Automates reference counting and behaves as shared smart pointer.
Definition: ref.hpp:117
std::uint32_t uint
Library index and size type.
Definition: config.hpp:56
Definition: algorithm.hpp:37
Execution context of a single task.
Definition: dispatcher.hpp:46
ref_ptr< ScheduleTask > task
Definition: dispatcher.hpp:48
#define TIME_PROFILE_SCOPE(name)
Definition: time_profiler.hpp:92