40 auto* cl_acc = get_acc_cl();
41 const uint block_size = std::min(cl_acc->get_max_wgs(),
uint(256));
42 const uint values_per_block = block_size * 2;
52 .
add_type(
"TYPE", get_ttype<T>().
template as<Type>())
54 .
add_define(
"WARP_SIZE", cl_acc->get_wave_size())
55 .
add_define(
"LM_NUM_MEM_BANKS", cl_acc->get_num_of_mem_banks())
56 .
add_op(
"OP_BINARY", op.template as<OpBinary>())
60 uint n_groups_to_run = n / values_per_block + (n % values_per_block ? 1 : 0);
61 cl::Buffer cl_carry = tmp_alloc->
alloc(
sizeof(T) * n_groups_to_run);
63 auto kernel_prescan = builder.
make_kernel(
"prefix_sum_prescan_unroll");
64 kernel_prescan.setArg(0, values);
65 kernel_prescan.setArg(1, cl_carry);
66 kernel_prescan.setArg(2, n);
68 cl::NDRange prescan_global(n_groups_to_run * block_size);
69 cl::NDRange prescan_local(block_size);
70 queue.enqueueNDRangeKernel(kernel_prescan, cl::NDRange(), prescan_global, prescan_local);
72 if (n_groups_to_run > 1) {
75 auto kernel_propagate = builder.
make_kernel(
"prefix_sum_propagate");
76 kernel_propagate.setArg(0, values);
77 kernel_propagate.setArg(1, cl_carry);
78 kernel_propagate.setArg(2, n);
80 cl::NDRange propagate_global((n_groups_to_run - 1) * values_per_block);
81 cl::NDRange propagate_local(block_size);
82 queue.enqueueNDRangeKernel(kernel_propagate, cl::NDRange(), propagate_global, propagate_local);
void cl_exclusive_scan(cl::CommandQueue &queue, cl::Buffer &values, uint n, const ref_ptr< TOpBinary< T, T, T > > &op, CLAlloc *tmp_alloc)
Definition cl_prefix_sum.hpp:39