// Safely downcast the generic scheduled task to the masked mxmT task
auto t = ctx.task.template cast_safe<ScheduleTask_mxmT_masked>();

// ...
std::shared_ptr<CLProgram> program;

// ...
// Device-side CSR storage of the result, mask and input matrices
auto*       p_cl_R    = R->template get<CLCsr<T>>();
const auto* p_cl_mask = mask->template get<CLCsr<T>>();
const auto* p_cl_A    = A->template get<CLCsr<T>>();
const auto* p_cl_B    = B->template get<CLCsr<T>>();
// Nothing to compute if any operand stores no values
// (the early-exit bodies are elided in this excerpt)
if (p_cl_mask->values == 0) { /* ... */ }
if (p_cl_A->values == 0) { /* ... */ }
if (p_cl_B->values == 0) { /* ... */ }
auto* p_cl_acc = get_acc_cl();
auto& queue    = p_cl_acc->get_queue_default();
// R inherits the mask's sparsity pattern: copy the mask's row pointers
// (n_rows + 1 entries) and column indices directly into R
queue.enqueueCopyBuffer(p_cl_mask->Ap, p_cl_R->Ap, 0, 0,
                        sizeof(uint) * (R->get_n_rows() + 1));
queue.enqueueCopyBuffer(p_cl_mask->Aj, p_cl_R->Aj, 0, 0,
                        sizeof(uint) * (p_cl_R->values));
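// Note: assuming get_queue_default() returns an in-order queue (the OpenCL
// default), the kernel enqueued below is ordered after these copies.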
auto kernel = program->make_kernel("mxmT_masked_csr_scalar");
kernel.setArg(0, p_cl_A->Ap);
kernel.setArg(1, p_cl_A->Aj);
kernel.setArg(2, p_cl_A->Ax);
kernel.setArg(3, p_cl_B->Ap);
kernel.setArg(4, p_cl_B->Aj);
kernel.setArg(5, p_cl_B->Ax);
kernel.setArg(6, p_cl_mask->Ap);
kernel.setArg(7, p_cl_mask->Aj);
kernel.setArg(8, p_cl_mask->Ax);
kernel.setArg(9, p_cl_R->Ax);
kernel.setArg(10, T(init->get_value()));
kernel.setArg(11, R->get_n_rows());
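// Argument layout: 0-2 are A's CSR arrays (Ap/Aj/Ax), 3-5 are B's, 6-8 are
// the mask's, 9 is R's value array (its pattern was already copied from the
// mask), 10 is the scalar taken from init (presumably the accumulator's
// initial value), and 11 is the number of rows to process.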

uint n_groups_to_dispatch = div_up_clamp(R->get_n_rows(), m_block_count, 1, 1024);
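// div_up_clamp is not shown in this excerpt; a minimal sketch under the
// assumed semantics "ceil-divide n by d, then clamp into [lo, hi]":
//
//     inline uint div_up_clamp(uint n, uint d, uint lo, uint hi) {
//         const uint groups = (n + d - 1) / d;       // ceil(n / d)
//         return std::min(std::max(groups, lo), hi); // needs <algorithm>
//     }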

// Each work-group is m_block_count x m_block_size work items;
// n_groups_to_dispatch groups are launched along dimension 0
cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size);
cl::NDRange exec_local(m_block_count, m_block_size);
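// Worked example with assumed values (not from the source): for
// m_block_count = 4, m_block_size = 32 and 10'000 rows, div_up_clamp gives
// min(max(ceil(10000 / 4), 1), 1024) = 1024 groups, so exec_global is
// (4096, 32) and exec_local is (4, 32). With more rows than row-slots, the
// kernel itself must stride over the remaining rows (not shown here).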

// ... (the enclosing helper that builds the program is shown only partially;
// the beginning of its signature is elided in this excerpt)
                       std::shared_ptr<CLProgram>& program) {
    m_block_size = get_acc_cl()->get_default_wgs();
    // ...
    assert(m_block_count >= 1);
    // ... (builder construction and earlier chained calls elided)
        // Inject the hardware warp size, the value type and the
        // multiply/add/select operators into the kernel source
        .add_define("WARP_SIZE", get_acc_cl()->get_wave_size())
        // ...
        .add_type("TYPE", get_ttype<T>().template as<Type>())
        .add_op("OP_BINARY1", op_multiply.template as<OpBinary>())
        .add_op("OP_BINARY2", op_add.template as<OpBinary>())
        .add_op("OP_SELECT", op_select.template as<OpSelect>())