cl_mxv.hpp
/**********************************************************************************/
/* This file is part of spla project                                              */
/* https://github.com/SparseLinearAlgebra/spla                                    */
/**********************************************************************************/
/* MIT License                                                                    */
/*                                                                                */
/* Copyright (c) 2023 SparseLinearAlgebra                                         */
/*                                                                                */
/* Permission is hereby granted, free of charge, to any person obtaining a copy   */
/* of this software and associated documentation files (the "Software"), to deal  */
/* in the Software without restriction, including without limitation the rights   */
/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell      */
/* copies of the Software, and to permit persons to whom the Software is          */
/* furnished to do so, subject to the following conditions:                       */
/*                                                                                */
/* The above copyright notice and this permission notice shall be included in all */
/* copies or substantial portions of the Software.                                */
/*                                                                                */
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR     */
/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,       */
/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE    */
/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER         */
/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,  */
/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE  */
/* SOFTWARE.                                                                      */
/**********************************************************************************/

#ifndef SPLA_CL_MXV_HPP
#define SPLA_CL_MXV_HPP

#include <schedule/schedule_tasks.hpp>// ScheduleTask_mxv_masked

#include <core/dispatcher.hpp>
#include <core/registry.hpp>
#include <core/tmatrix.hpp>
#include <core/top.hpp>
#include <core/tscalar.hpp>
#include <core/ttype.hpp>
#include <core/tvector.hpp>

#include <opencl/cl_counter.hpp>
#include <opencl/cl_debug.hpp>
#include <opencl/cl_formats.hpp>
#include <opencl/cl_program_builder.hpp>// CLProgramBuilder
#include <opencl/generated/auto_mxv.hpp>// source_mxv kernel source

#include <algorithm>
#include <sstream>

namespace spla {

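    /**
     * @brief Masked matrix-vector product over a CSR matrix and dense vectors on OpenCL.
     *
     * Three kernel strategies are implemented below: a CSR-vector kernel ("mxv_vector",
     * BLOCK_SIZE lanes cooperating per row), a CSR-scalar kernel ("mxv_scalar", one
     * work-item per row), and a two-pass variant ("mxv_config" + "mxv_config_scalar")
     * that first compacts the rows selected by the mask and then processes only those.
     */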
    template<typename T>
    class Algo_mxv_masked_cl final : public RegistryAlgo {
    public:
        ~Algo_mxv_masked_cl() override = default;

        std::string get_name() override {
            return "mxv_masked";
        }

        std::string get_description() override {
            return "parallel matrix-vector masked product on opencl device";
        }

        Status execute(const DispatchContext& ctx) override {
            auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();
            auto early_exit = t->get_desc_or_default()->get_early_exit();

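            // The early-exit hint selects the two-pass config-scalar path, which restricts
            // work to rows still admitted by the mask; otherwise use the vector kernel.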
            if (early_exit) {
                return execute_config_scalar(ctx);
            } else {
                return execute_vector(ctx);
            }
        }

    private:
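        // CSR-vector strategy: 2D NDRange where each work-group covers BLOCK_COUNT rows
        // and BLOCK_SIZE lanes cooperate on the non-zeros of each row.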
        Status execute_vector(const DispatchContext& ctx) {
            TIME_PROFILE_SCOPE("opencl/mxv/vector");

            auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();

            ref_ptr<TVector<T>> r = t->r.template cast_safe<TVector<T>>();
            ref_ptr<TVector<T>> mask = t->mask.template cast_safe<TVector<T>>();
            ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
            ref_ptr<TVector<T>> v = t->v.template cast_safe<TVector<T>>();
            ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpBinary<T, T, T>> op_add = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpSelect<T>> op_select = t->op_select.template cast_safe<TOpSelect<T>>();
            ref_ptr<TScalar<T>> init = t->init.template cast_safe<TScalar<T>>();

            r->validate_wd(FormatVector::AccDense);
            mask->validate_rw(FormatVector::AccDense);
            M->validate_rw(FormatMatrix::AccCsr);
            v->validate_rw(FormatVector::AccDense);

            std::shared_ptr<CLProgram> program;
            if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError;

            auto* p_cl_r = r->template get<CLDenseVec<T>>();
            auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
            auto* p_cl_M = M->template get<CLCsr<T>>();
            auto* p_cl_v = v->template get<CLDenseVec<T>>();

            auto* p_cl_acc = get_acc_cl();
            auto& queue = p_cl_acc->get_queue_default();

            auto kernel_vector = program->make_kernel("mxv_vector");
            kernel_vector.setArg(0, p_cl_M->Ap);
            kernel_vector.setArg(1, p_cl_M->Aj);
            kernel_vector.setArg(2, p_cl_M->Ax);
            kernel_vector.setArg(3, p_cl_v->Ax);
            kernel_vector.setArg(4, p_cl_mask->Ax);
            kernel_vector.setArg(5, p_cl_r->Ax);
            kernel_vector.setArg(6, init->get_value());
            kernel_vector.setArg(7, r->get_n_rows());

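            // Cover all rows, clamping the group count to [1, 512] so small and large
            // inputs both get a reasonable grid.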
            uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_count, 1, 512);

            cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size);
            cl::NDRange exec_local(m_block_count, m_block_size);
            CL_DISPATCH_PROFILED("exec", queue, kernel_vector, cl::NDRange(), exec_global, exec_local);

            return Status::Ok;
        }

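        // CSR-scalar strategy: 1D NDRange with one work-item per matrix row; simple, but
        // load-imbalanced when row lengths vary widely.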
        Status execute_scalar(const DispatchContext& ctx) {
            TIME_PROFILE_SCOPE("opencl/mxv/scalar");

            auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();

            ref_ptr<TVector<T>> r = t->r.template cast_safe<TVector<T>>();
            ref_ptr<TVector<T>> mask = t->mask.template cast_safe<TVector<T>>();
            ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
            ref_ptr<TVector<T>> v = t->v.template cast_safe<TVector<T>>();
            ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpBinary<T, T, T>> op_add = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpSelect<T>> op_select = t->op_select.template cast_safe<TOpSelect<T>>();
            ref_ptr<TScalar<T>> init = t->init.template cast_safe<TScalar<T>>();

            r->validate_wd(FormatVector::AccDense);
            mask->validate_rw(FormatVector::AccDense);
            M->validate_rw(FormatMatrix::AccCsr);
            v->validate_rw(FormatVector::AccDense);

            std::shared_ptr<CLProgram> program;
            if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError;

            auto* p_cl_r = r->template get<CLDenseVec<T>>();
            auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
            auto* p_cl_M = M->template get<CLCsr<T>>();
            auto* p_cl_v = v->template get<CLDenseVec<T>>();
            auto early_exit = t->get_desc_or_default()->get_early_exit();

            auto* p_cl_acc = get_acc_cl();
            auto& queue = p_cl_acc->get_queue_default();

            auto kernel_scalar = program->make_kernel("mxv_scalar");
            kernel_scalar.setArg(0, p_cl_M->Ap);
            kernel_scalar.setArg(1, p_cl_M->Aj);
            kernel_scalar.setArg(2, p_cl_M->Ax);
            kernel_scalar.setArg(3, p_cl_v->Ax);
            kernel_scalar.setArg(4, p_cl_mask->Ax);
            kernel_scalar.setArg(5, p_cl_r->Ax);
            kernel_scalar.setArg(6, init->get_value());
            kernel_scalar.setArg(7, r->get_n_rows());
            kernel_scalar.setArg(8, uint(early_exit));

            uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_size, 1, 512);

            cl::NDRange exec_global(m_block_size * n_groups_to_dispatch);
            cl::NDRange exec_local(m_block_size);
            CL_DISPATCH_PROFILED("exec", queue, kernel_scalar, cl::NDRange(), exec_global, exec_local);

            return Status::Ok;
        }

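        // Config-scalar strategy: pass one scans the mask and fills a compact worklist of
        // row ids; pass two runs the scalar kernel over that worklist only.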
        Status execute_config_scalar(const DispatchContext& ctx) {
            TIME_PROFILE_SCOPE("opencl/mxv/config-scalar");

            auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();

            ref_ptr<TVector<T>> r = t->r.template cast_safe<TVector<T>>();
            ref_ptr<TVector<T>> mask = t->mask.template cast_safe<TVector<T>>();
            ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
            ref_ptr<TVector<T>> v = t->v.template cast_safe<TVector<T>>();
            ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpBinary<T, T, T>> op_add = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpSelect<T>> op_select = t->op_select.template cast_safe<TOpSelect<T>>();
            ref_ptr<TScalar<T>> init = t->init.template cast_safe<TScalar<T>>();

            r->validate_wd(FormatVector::AccDense);
            mask->validate_rw(FormatVector::AccDense);
            M->validate_rw(FormatMatrix::AccCsr);
            v->validate_rw(FormatVector::AccDense);

            std::shared_ptr<CLProgram> program;
            if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError;

            auto* p_cl_r = r->template get<CLDenseVec<T>>();
            auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
            auto* p_cl_M = M->template get<CLCsr<T>>();
            auto* p_cl_v = v->template get<CLDenseVec<T>>();
            auto early_exit = t->get_desc_or_default()->get_early_exit();

            auto* p_cl_acc = get_acc_cl();
            auto& queue = p_cl_acc->get_queue_default();

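            // cl_config receives the row ids selected by the mask; cl_config_size is a
            // device-side counter, zero-initialized from the host and read back below.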
            uint config_size = 0;
            cl::Buffer cl_config(p_cl_acc->get_context(), CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, sizeof(uint) * M->get_n_rows());
            cl::Buffer cl_config_size(p_cl_acc->get_context(), CL_MEM_READ_WRITE | CL_MEM_HOST_READ_ONLY | CL_MEM_COPY_HOST_PTR, sizeof(uint), &config_size);

            auto kernel_config = program->make_kernel("mxv_config");
            kernel_config.setArg(0, p_cl_mask->Ax);
            kernel_config.setArg(1, p_cl_r->Ax);
            kernel_config.setArg(2, cl_config);
            kernel_config.setArg(3, cl_config_size);
            kernel_config.setArg(4, init->get_value());
            kernel_config.setArg(5, M->get_n_rows());

            uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_size, 1, 1024);

            cl::NDRange config_global(m_block_size * n_groups_to_dispatch);
            cl::NDRange config_local(m_block_size);
            CL_DISPATCH_PROFILED("config", queue, kernel_config, cl::NDRange(), config_global, config_local);

            CL_READ_PROFILED("config-size", queue, cl_config_size, true, 0, sizeof(config_size), &config_size);

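            // With the worklist size now known on the host, size the second dispatch by the
            // number of selected rows rather than by the full row count.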
            auto kernel_config_scalar = program->make_kernel("mxv_config_scalar");
            kernel_config_scalar.setArg(0, p_cl_M->Ap);
            kernel_config_scalar.setArg(1, p_cl_M->Aj);
            kernel_config_scalar.setArg(2, p_cl_M->Ax);
            kernel_config_scalar.setArg(3, p_cl_v->Ax);
            kernel_config_scalar.setArg(4, cl_config);
            kernel_config_scalar.setArg(5, p_cl_r->Ax);
            kernel_config_scalar.setArg(6, init->get_value());
            kernel_config_scalar.setArg(7, config_size);
            kernel_config_scalar.setArg(8, uint(early_exit));

            n_groups_to_dispatch = div_up_clamp(config_size, m_block_size, 1, 1024);

            cl::NDRange exec_global(m_block_size * n_groups_to_dispatch);
            cl::NDRange exec_local(m_block_size);
            CL_DISPATCH_PROFILED("exec", queue, kernel_config_scalar, cl::NDRange(), exec_global, exec_local);

            return Status::Ok;
        }

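        // Compiles the OpenCL program specialized for the element type and for the
        // multiply/add/select operators injected into the kernel source.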
        bool ensure_kernel(const ref_ptr<TOpBinary<T, T, T>>& op_multiply,
                           const ref_ptr<TOpBinary<T, T, T>>& op_add,
                           const ref_ptr<TOpSelect<T>>& op_select,
                           std::shared_ptr<CLProgram>& program) {
            m_block_size = get_acc_cl()->get_wave_size();
            m_block_count = 1;

            assert(m_block_count >= 1);

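            // WARP_SIZE, BLOCK_SIZE and BLOCK_COUNT become compile-time constants of the
            // kernels; the operator bodies are code-generated as OP_BINARY1/2 and OP_SELECT.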
            CLProgramBuilder program_builder;
            program_builder
                    .set_name("mxv")
                    .add_define("WARP_SIZE", get_acc_cl()->get_wave_size())
                    .add_define("BLOCK_SIZE", m_block_size)
                    .add_define("BLOCK_COUNT", m_block_count)
                    .add_type("TYPE", get_ttype<T>().template as<Type>())
                    .add_op("OP_BINARY1", op_multiply.template as<OpBinary>())
                    .add_op("OP_BINARY2", op_add.template as<OpBinary>())
                    .add_op("OP_SELECT", op_select.template as<OpSelect>())
                    .set_source(source_mxv)
                    .acquire();

            program = program_builder.get_program();

            return true;
        }

    private:
        uint m_block_size = 0;
        uint m_block_count = 0;
    };

}// namespace spla

#endif//SPLA_CL_MXV_HPP