#ifndef SPLA_CL_VXM_HPP
#define SPLA_CL_VXM_HPP
std::string get_description() override {
    return "parallel vector-matrix masked product on opencl device";
}

Status execute(const DispatchContext& ctx) override {
    return execute_sparse(ctx);
}
// execute_sparse: cast the scheduled task to its concrete vxm_masked type
auto t = ctx.task.template cast_safe<ScheduleTask_vxm_masked>();
std::shared_ptr<CLProgram> program;
auto* p_cl_r    = r->template get<CLCooVec<T>>();
auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
auto* p_cl_M    = M->template get<CLCsr<T>>();
auto* p_cl_v    = v->template get<CLCooVec<T>>();

auto* p_cl_acc    = get_acc_cl();
auto* p_tmp_alloc = p_cl_acc->get_alloc_tmp();
auto& queue       = p_cl_acc->get_queue_default();
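// First pass ("count"): judging by its name and arguments, vxm_sparse_count walks
// the nonzeros of the sparse vector v against the CSR rows of M and accumulates the
// number of mask-passing products into the cl_prods_count device counter (declared,
// together with the host-side prods_count, in the elided lines above).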
auto kernel_sparse_count = program->make_kernel("vxm_sparse_count");
kernel_sparse_count.setArg(0, p_cl_v->Ai);
kernel_sparse_count.setArg(1, p_cl_v->Ax);
kernel_sparse_count.setArg(2, p_cl_M->Ap);
kernel_sparse_count.setArg(3, p_cl_M->Aj);
kernel_sparse_count.setArg(4, p_cl_mask->Ax);
kernel_sparse_count.setArg(5, cl_prods_count.buffer());
kernel_sparse_count.setArg(6, p_cl_v->values);
uint n_groups_to_dispatch_v = div_up_clamp(p_cl_v->values, m_block_size, 1, 1024);
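// One work-group of m_block_size threads is launched per block of v's stored
// entries; div_up_clamp appears to round the group count up and clamp it to [1, 1024].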
cl::NDRange count_global(m_block_size * n_groups_to_dispatch_v);
cl::NDRange count_local(m_block_size);
CL_DISPATCH_PROFILED("count", queue, kernel_sparse_count, cl::NDRange(), count_global, count_local);
CL_COUNTER_GET("copy_prods_count", queue, cl_prods_count, prods_count);
if (prods_count == 0) {
    p_cl_r->Ai = cl::Buffer();
    p_cl_r->Ax = cl::Buffer();
    // ... (rest of the early-exit branch elided)
}
cl::Buffer cl_prodi = p_tmp_alloc->alloc(prods_count * sizeof(uint));
cl::Buffer cl_prodx = p_tmp_alloc->alloc(prods_count * sizeof(T));
CLCounterWrapper cl_prods_offset;
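// Second pass ("collect"): the surviving products are materialized as
// (row index, value) pairs in cl_prodi/cl_prodx; cl_prods_offset looks like a
// shared write cursor the work-groups bump to claim output slots.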
auto kernel_sparse_collect = program->make_kernel("vxm_sparse_collect");
kernel_sparse_collect.setArg(0, p_cl_v->Ai);
kernel_sparse_collect.setArg(1, p_cl_v->Ax);
kernel_sparse_collect.setArg(2, p_cl_M->Ap);
kernel_sparse_collect.setArg(3, p_cl_M->Aj);
kernel_sparse_collect.setArg(4, p_cl_M->Ax);
kernel_sparse_collect.setArg(5, p_cl_mask->Ax);
kernel_sparse_collect.setArg(6, cl_prodi);
kernel_sparse_collect.setArg(7, cl_prodx);
kernel_sparse_collect.setArg(8, cl_prods_offset.buffer());
kernel_sparse_collect.setArg(9, p_cl_v->values);
cl::NDRange collect_global(m_block_size * n_groups_to_dispatch_v);
cl::NDRange collect_local(m_block_size);
CL_DISPATCH_PROFILED("collect", queue, kernel_sparse_collect, cl::NDRange(), collect_global, collect_local);
const uint max_key    = r->get_n_rows() - 1;
const uint n_elements = prods_count;
cl_sort_by_key<T>(queue, cl_prodi, cl_prodx, n_elements, p_tmp_alloc, max_key);
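// Duplicate row indices are merged with op_add: cl_reduce_by_key writes compacted
// keys/values into fresh buffers and reports the reduced element count.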
cl::Buffer reduced_keys;
cl::Buffer reduced_values;
uint       reduced_size = 0;

cl_reduce_by_key(queue, cl_prodi, cl_prodx, prods_count, reduced_keys, reduced_values, reduced_size, op_add, p_tmp_alloc);
p_cl_r->Ai     = reduced_keys;
p_cl_r->Ax     = reduced_values;
p_cl_r->values = reduced_size;

p_tmp_alloc->free_all();
bool ensure_kernel(const ref_ptr<TOpBinary<T, T, T>>& op_multiply,
                   const ref_ptr<TOpBinary<T, T, T>>& op_add,
                   const ref_ptr<TOpSelect<T>>&       op_select,
                   std::shared_ptr<CLProgram>&        program) {
m_block_size = get_acc_cl()->get_default_wgs();
assert(m_block_count >= 1);
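// The OpenCL program is specialized at run time: BLOCK_SIZE, the scalar TYPE and
// the multiply/add/select operators are injected into the vxm kernel source by the
// program builder, so one source template serves every type/operator combination
// (judging by the builder calls below).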
CLProgramBuilder program_builder;
program_builder
        .add_define("BLOCK_SIZE", m_block_size)
        .add_type("TYPE", get_ttype<T>().template as<Type>())
        .add_op("OP_BINARY1", op_multiply.template as<OpBinary>())
        .add_op("OP_BINARY2", op_add.template as<OpBinary>())
        .add_op("OP_SELECT", op_select.template as<OpSelect>())
        .set_source(source_vxm);

program = program_builder.get_program();
uint m_block_size  = 0;
uint m_block_count = 0;