#ifndef SPLA_CL_MXV_HPP
#define SPLA_CL_MXV_HPP
62 return "parallel matrix-vector masked product on opencl device";
        Status execute(const DispatchContext& ctx) override {
            auto t          = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();
            auto early_exit = t->get_desc_or_default()->get_early_exit();

            // When early exit is requested, take the mask-prefiltering
            // (config) path; otherwise run the plain vector kernel.
            if (early_exit) {
                return execute_config_scalar(ctx);
            }
            return execute_vector(ctx);
        }
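        // What every path below computes, sketched on the host for reference
        // (an illustration only; it assumes a dense mask and that rows
        // filtered out by the select op receive `init`):
        //
        //   for (uint i = 0; i < n_rows; ++i) {
        //       if (!OP_SELECT(mask[i])) { r[i] = init; continue; }
        //       T sum = init;
        //       for (uint k = Ap[i]; k < Ap[i + 1]; ++k)
        //           sum = OP_BINARY2(sum, OP_BINARY1(Ax[k], v[Aj[k]]));
        //       r[i] = sum;
        //   }
        //
        // Ap/Aj/Ax are the CSR buffers of M; OP_BINARY1/OP_BINARY2/OP_SELECT
        // are the multiply/add/select ops injected by ensure_kernel() below.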
    private:
        Status execute_vector(const DispatchContext& ctx) {
            auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();

            ref_ptr<TVector<T>>         r           = t->r.template cast_safe<TVector<T>>();
            ref_ptr<TVector<T>>         mask        = t->mask.template cast_safe<TVector<T>>();
            ref_ptr<TMatrix<T>>         M           = t->M.template cast_safe<TMatrix<T>>();
            ref_ptr<TVector<T>>         v           = t->v.template cast_safe<TVector<T>>();
            ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpBinary<T, T, T>> op_add      = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpSelect<T>>       op_select   = t->op_select.template cast_safe<TOpSelect<T>>();
            ref_ptr<TScalar<T>>         init        = t->init.template cast_safe<TScalar<T>>();

            std::shared_ptr<CLProgram> program;
            if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::Error;// compile-on-demand; the exact failure status is this listing's assumption
            auto* p_cl_r    = r->template get<CLDenseVec<T>>();
            auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
            auto* p_cl_M    = M->template get<CLCsr<T>>();
            auto* p_cl_v    = v->template get<CLDenseVec<T>>();

            auto* p_cl_acc = get_acc_cl();
            auto& queue    = p_cl_acc->get_queue_default();
            auto kernel_vector = program->make_kernel("mxv_vector");
            kernel_vector.setArg(0, p_cl_M->Ap);        // CSR row offsets
            kernel_vector.setArg(1, p_cl_M->Aj);        // CSR column indices
            kernel_vector.setArg(2, p_cl_M->Ax);        // CSR values
            kernel_vector.setArg(3, p_cl_v->Ax);        // dense input vector
            kernel_vector.setArg(4, p_cl_mask->Ax);     // dense mask
            kernel_vector.setArg(5, p_cl_r->Ax);        // dense result
            kernel_vector.setArg(6, init->get_value()); // identity / fill value
            kernel_vector.setArg(7, r->get_n_rows());
            uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_count, 1, 512);

            cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size);
            cl::NDRange exec_local(m_block_count, m_block_size);
            CL_DISPATCH_PROFILED("exec", queue, kernel_vector, cl::NDRange(), exec_global, exec_local);

            return Status::Ok;
        }
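        // A sketch of the sizing helper used above (assumed semantics, matching
        // its call sites; the actual definition lives in the library utilities):
        //
        //   static inline uint div_up_clamp(uint n, uint block, uint lo, uint hi) {
        //       const uint groups = (n + block - 1) / block;// ceil(n / block)
        //       return std::min(std::max(groups, lo), hi);
        //   }
        //
        // The vector kernel thus gets a 2D range: one sub-group ("wave") of
        // m_block_size threads per matrix row, m_block_count rows per
        // work-group, and at most 512 work-groups in flight.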
        Status execute_scalar(const DispatchContext& ctx) {
            auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();
            ref_ptr<TVector<T>>         r           = t->r.template cast_safe<TVector<T>>();
            ref_ptr<TVector<T>>         mask        = t->mask.template cast_safe<TVector<T>>();
            ref_ptr<TMatrix<T>>         M           = t->M.template cast_safe<TMatrix<T>>();
            ref_ptr<TVector<T>>         v           = t->v.template cast_safe<TVector<T>>();
            ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpBinary<T, T, T>> op_add      = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpSelect<T>>       op_select   = t->op_select.template cast_safe<TOpSelect<T>>();
            ref_ptr<TScalar<T>>         init        = t->init.template cast_safe<TScalar<T>>();
            std::shared_ptr<CLProgram> program;
            if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::Error;
            auto* p_cl_r    = r->template get<CLDenseVec<T>>();
            auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
            auto* p_cl_M    = M->template get<CLCsr<T>>();
            auto* p_cl_v    = v->template get<CLDenseVec<T>>();

            auto early_exit = t->get_desc_or_default()->get_early_exit();

            auto* p_cl_acc = get_acc_cl();
            auto& queue    = p_cl_acc->get_queue_default();
            auto kernel_scalar = program->make_kernel("mxv_scalar");
            kernel_scalar.setArg(0, p_cl_M->Ap);
            kernel_scalar.setArg(1, p_cl_M->Aj);
            kernel_scalar.setArg(2, p_cl_M->Ax);
            kernel_scalar.setArg(3, p_cl_v->Ax);
            kernel_scalar.setArg(4, p_cl_mask->Ax);
            kernel_scalar.setArg(5, p_cl_r->Ax);
            kernel_scalar.setArg(6, init->get_value());
            kernel_scalar.setArg(7, r->get_n_rows());
            kernel_scalar.setArg(8, uint(early_exit));
            uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_size, 1, 512);

            cl::NDRange exec_global(m_block_size * n_groups_to_dispatch);
            cl::NDRange exec_local(m_block_size);
            CL_DISPATCH_PROFILED("exec", queue, kernel_scalar, cl::NDRange(), exec_global, exec_local);

            return Status::Ok;
        }
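        // Unlike the vector kernel, the scalar kernel uses a flat 1D range:
        // one work-item per matrix row. The extra early_exit flag (arg 8) is
        // presumably a hint that lets the kernel cut a row's scan short once
        // the reduction is saturated (e.g. a logical-or that has already
        // produced true).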
        Status execute_config_scalar(const DispatchContext& ctx) {
            auto t = ctx.task.template cast_safe<ScheduleTask_mxv_masked>();
            ref_ptr<TVector<T>>         r           = t->r.template cast_safe<TVector<T>>();
            ref_ptr<TVector<T>>         mask        = t->mask.template cast_safe<TVector<T>>();
            ref_ptr<TMatrix<T>>         M           = t->M.template cast_safe<TMatrix<T>>();
            ref_ptr<TVector<T>>         v           = t->v.template cast_safe<TVector<T>>();
            ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpBinary<T, T, T>> op_add      = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpSelect<T>>       op_select   = t->op_select.template cast_safe<TOpSelect<T>>();
            ref_ptr<TScalar<T>>         init        = t->init.template cast_safe<TScalar<T>>();
            std::shared_ptr<CLProgram> program;
            if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::Error;
            auto* p_cl_r    = r->template get<CLDenseVec<T>>();
            auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
            auto* p_cl_M    = M->template get<CLCsr<T>>();
            auto* p_cl_v    = v->template get<CLDenseVec<T>>();

            auto early_exit = t->get_desc_or_default()->get_early_exit();

            auto* p_cl_acc = get_acc_cl();
            auto& queue    = p_cl_acc->get_queue_default();
            // Buffers for the compacted row list and its length.
            uint       config_size = 0;
            cl::Buffer cl_config(p_cl_acc->get_context(), CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, sizeof(uint) * M->get_n_rows());
            cl::Buffer cl_config_size(p_cl_acc->get_context(), CL_MEM_READ_WRITE | CL_MEM_HOST_READ_ONLY | CL_MEM_COPY_HOST_PTR, sizeof(uint), &config_size);
            auto kernel_config = program->make_kernel("mxv_config");
            kernel_config.setArg(0, p_cl_mask->Ax);
            kernel_config.setArg(1, p_cl_r->Ax);
            kernel_config.setArg(2, cl_config);
            kernel_config.setArg(3, cl_config_size);
            kernel_config.setArg(4, init->get_value());
            kernel_config.setArg(5, M->get_n_rows());
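            // Phase 1: judging by its argument list, mxv_config scans the
            // mask, seeds the result with init, and compacts the ids of rows
            // that pass OP_SELECT into cl_config, counting them in
            // cl_config_size (the kernel text itself lives in source_mxv).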
            uint n_groups_to_dispatch = div_up_clamp(r->get_n_rows(), m_block_size, 1, 1024);

            cl::NDRange config_global(m_block_size * n_groups_to_dispatch);
            cl::NDRange config_local(m_block_size);
            CL_DISPATCH_PROFILED("config", queue, kernel_config, cl::NDRange(), config_global, config_local);
            CL_READ_PROFILED("config-size", queue, cl_config_size, true, 0, sizeof(config_size), &config_size);
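            // The read above is blocking (the `true` argument), so config_size
            // now holds the number of surviving rows; the second phase below
            // is sized by it rather than by the full row count.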
            // Phase 2: run the scalar kernel over the compacted rows only.
            auto kernel_config_scalar = program->make_kernel("mxv_config_scalar");
            kernel_config_scalar.setArg(0, p_cl_M->Ap);
            kernel_config_scalar.setArg(1, p_cl_M->Aj);
            kernel_config_scalar.setArg(2, p_cl_M->Ax);
            kernel_config_scalar.setArg(3, p_cl_v->Ax);
            kernel_config_scalar.setArg(4, cl_config);
            kernel_config_scalar.setArg(5, p_cl_r->Ax);
            kernel_config_scalar.setArg(6, init->get_value());
            kernel_config_scalar.setArg(7, config_size);
            kernel_config_scalar.setArg(8, uint(early_exit));
            n_groups_to_dispatch = div_up_clamp(config_size, m_block_size, 1, 1024);

            cl::NDRange exec_global(m_block_size * n_groups_to_dispatch);
            cl::NDRange exec_local(m_block_size);
            CL_DISPATCH_PROFILED("exec", queue, kernel_config_scalar, cl::NDRange(), exec_global, exec_local);

            return Status::Ok;
        }
        bool ensure_kernel(const ref_ptr<TOpBinary<T, T, T>>& op_multiply,
                           const ref_ptr<TOpBinary<T, T, T>>& op_add,
                           const ref_ptr<TOpSelect<T>>&       op_select,
                           std::shared_ptr<CLProgram>&        program) {
            m_block_size = get_acc_cl()->get_wave_size();
            assert(m_block_count >= 1);// rows per work-group for the vector kernel; must be at least 1
            CLProgramBuilder program_builder;
            program_builder
                    .add_define("WARP_SIZE", get_acc_cl()->get_wave_size())
                    .add_define("BLOCK_SIZE", m_block_size)
                    .add_define("BLOCK_COUNT", m_block_count)
                    .add_type("TYPE", get_ttype<T>().template as<Type>())
                    .add_op("OP_BINARY1", op_multiply.template as<OpBinary>())
                    .add_op("OP_BINARY2", op_add.template as<OpBinary>())
                    .add_op("OP_SELECT", op_select.template as<OpSelect>())
                    .set_source(source_mxv);

            program = program_builder.get_program();
            return true;
        }
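        // The builder above effectively prepends a preamble to source_mxv
        // before compiling, specializing the generic kernel text per type/op
        // combination. Illustrative only (the exact expansion is up to
        // CLProgramBuilder); for T = int with plus/times and a nonzero-select
        // mask it would resemble:
        //
        //   #define WARP_SIZE   32
        //   #define BLOCK_SIZE  32
        //   #define BLOCK_COUNT 1
        //   #define TYPE        int
        //   TYPE OP_BINARY1(TYPE a, TYPE b) { return a * b; }
        //   TYPE OP_BINARY2(TYPE a, TYPE b) { return a + b; }
        //   bool OP_SELECT(TYPE a)          { return a != 0; }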
        uint m_block_size  = 0;
        uint m_block_count = 0;
    };

}// namespace spla

#endif//SPLA_CL_MXV_HPP