spla
cl_vxm.hpp
/**********************************************************************************/
/* This file is part of spla project                                              */
/* https://github.com/SparseLinearAlgebra/spla                                    */
/**********************************************************************************/
/* MIT License                                                                    */
/*                                                                                */
/* Copyright (c) 2023 SparseLinearAlgebra                                         */
/*                                                                                */
/* Permission is hereby granted, free of charge, to any person obtaining a copy   */
/* of this software and associated documentation files (the "Software"), to deal  */
/* in the Software without restriction, including without limitation the rights   */
/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell      */
/* copies of the Software, and to permit persons to whom the Software is          */
/* furnished to do so, subject to the following conditions:                       */
/*                                                                                */
/* The above copyright notice and this permission notice shall be included in all */
/* copies or substantial portions of the Software.                                */
/*                                                                                */
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR     */
/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,       */
/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE    */
/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER         */
/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,  */
/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE  */
/* SOFTWARE.                                                                      */
/**********************************************************************************/

#ifndef SPLA_CL_VXM_HPP
#define SPLA_CL_VXM_HPP

#include <schedule/schedule_tasks.hpp>

#include <core/dispatcher.hpp>
#include <core/registry.hpp>
#include <core/tmatrix.hpp>
#include <core/top.hpp>
#include <core/tscalar.hpp>
#include <core/ttype.hpp>
#include <core/tvector.hpp>

#include <opencl/cl_accelerator.hpp>
#include <opencl/cl_counter.hpp>
#include <opencl/cl_debug.hpp>
#include <opencl/cl_formats.hpp>
#include <opencl/cl_program_builder.hpp>
#include <opencl/cl_reduce_by_key.hpp>
#include <opencl/cl_sort_by_key.hpp>
#include <opencl/generated/auto_vxm.hpp>

#include <algorithm>
#include <sstream>

namespace spla {

    template<typename T>
    class Algo_vxm_masked_cl final : public RegistryAlgo {
    public:
        ~Algo_vxm_masked_cl() override = default;

        std::string get_name() override {
            return "vxm_masked";
        }

        std::string get_description() override {
            return "parallel vector-matrix masked product on opencl device";
        }

        Status execute(const DispatchContext& ctx) override {
            return execute_sparse(ctx);
        }

    private:
        Status execute_sparse(const DispatchContext& ctx) {
            TIME_PROFILE_SCOPE("opencl/vxm/sparse");
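            // The masked product r = mask .* (v * M) is assembled in four
            // device-side phases: (1) count the partial products v[i] * M[i,*]
            // that survive the mask, (2) collect them as (index, value) pairs,
            // (3) sort the pairs by result index, (4) reduce equal-index pairs
            // with op_add into the final sparse vector.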

            auto t = ctx.task.template cast_safe<ScheduleTask_vxm_masked>();

            ref_ptr<TVector<T>>         r           = t->r.template cast_safe<TVector<T>>();
            ref_ptr<TVector<T>>         mask        = t->mask.template cast_safe<TVector<T>>();
            ref_ptr<TVector<T>>         v           = t->v.template cast_safe<TVector<T>>();
            ref_ptr<TMatrix<T>>         M           = t->M.template cast_safe<TMatrix<T>>();
            ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpBinary<T, T, T>> op_add      = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
            ref_ptr<TOpSelect<T>>       op_select   = t->op_select.template cast_safe<TOpSelect<T>>();
            ref_ptr<TScalar<T>>         init        = t->init.template cast_safe<TScalar<T>>();

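            // Ensure each object is materialized in the storage format the kernels
            // expect: r is written as sparse COO, mask is read as dense, M as CSR,
            // and v as sparse COO.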
            r->validate_wd(FormatVector::AccCoo);
            mask->validate_rw(FormatVector::AccDense);
            M->validate_rw(FormatMatrix::AccCsr);
            v->validate_rw(FormatVector::AccCoo);

            std::shared_ptr<CLProgram> program;
            if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError;

            auto* p_cl_r    = r->template get<CLCooVec<T>>();
            auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
            auto* p_cl_M    = M->template get<CLCsr<T>>();
            auto* p_cl_v    = v->template get<CLCooVec<T>>();

            auto* p_cl_acc    = get_acc_cl();
            auto* p_tmp_alloc = p_cl_acc->get_alloc_tmp();
            auto& queue       = p_cl_acc->get_queue_default();

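            // Phase 1: count how many partial products survive the mask, so the
            // temporary buffers can be sized exactly before collection.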
            uint             prods_count;
            CLCounterWrapper cl_prods_count;
            CL_COUNTER_SET("init-prods-cnt", queue, cl_prods_count, 0);

            auto kernel_sparse_count = program->make_kernel("vxm_sparse_count");
            kernel_sparse_count.setArg(0, p_cl_v->Ai);
            kernel_sparse_count.setArg(1, p_cl_v->Ax);
            kernel_sparse_count.setArg(2, p_cl_M->Ap);
            kernel_sparse_count.setArg(3, p_cl_M->Aj);
            kernel_sparse_count.setArg(4, p_cl_mask->Ax);
            kernel_sparse_count.setArg(5, cl_prods_count.buffer());
            kernel_sparse_count.setArg(6, p_cl_v->values);

            uint n_groups_to_dispatch_v = div_up_clamp(p_cl_v->values, m_block_size, 1, 1024);

            cl::NDRange count_global(m_block_size * n_groups_to_dispatch_v);
            cl::NDRange count_local(m_block_size);
            CL_DISPATCH_PROFILED("count", queue, kernel_sparse_count, cl::NDRange(), count_global, count_local);

            CL_COUNTER_GET("copy_prods_count", queue, cl_prods_count, prods_count);
            LOG_MSG(Status::Ok, "temporary v[i] * M[i,*] products count " << prods_count);

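            // If nothing survives the mask, the result is the empty sparse vector.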
            if (prods_count == 0) {
                LOG_MSG(Status::Ok, "nothing to do");

                p_cl_r->Ai     = cl::Buffer();
                p_cl_r->Ax     = cl::Buffer();
                p_cl_r->values = 0;

                return Status::Ok;
            }

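            // Phase 2: collect the surviving (index, value) pairs into temporary
            // buffers; cl_prods_offset is bumped by the kernel to assign write slots.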
            cl::Buffer cl_prodi = p_tmp_alloc->alloc(prods_count * sizeof(uint));
            cl::Buffer cl_prodx = p_tmp_alloc->alloc(prods_count * sizeof(T));

            CLCounterWrapper cl_prods_offset;
            CL_COUNTER_SET("init-offsets-cnt", queue, cl_prods_offset, 0);

            auto kernel_sparse_collect = program->make_kernel("vxm_sparse_collect");
            kernel_sparse_collect.setArg(0, p_cl_v->Ai);
            kernel_sparse_collect.setArg(1, p_cl_v->Ax);
            kernel_sparse_collect.setArg(2, p_cl_M->Ap);
            kernel_sparse_collect.setArg(3, p_cl_M->Aj);
            kernel_sparse_collect.setArg(4, p_cl_M->Ax);
            kernel_sparse_collect.setArg(5, p_cl_mask->Ax);
            kernel_sparse_collect.setArg(6, cl_prodi);
            kernel_sparse_collect.setArg(7, cl_prodx);
            kernel_sparse_collect.setArg(8, cl_prods_offset.buffer());
            kernel_sparse_collect.setArg(9, p_cl_v->values);

            cl::NDRange collect_global(m_block_size * n_groups_to_dispatch_v);
            cl::NDRange collect_local(m_block_size);
            CL_DISPATCH_PROFILED("collect", queue, kernel_sparse_collect, cl::NDRange(), collect_global, collect_local);

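            // Phase 3: sort the pairs by result index so duplicates become adjacent.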
            CL_PROFILE_BEGIN("sort", queue)
            const uint max_key    = r->get_n_rows() - 1;
            const uint n_elements = prods_count;
            cl_sort_by_key<T>(queue, cl_prodi, cl_prodx, n_elements, p_tmp_alloc, max_key);
            CL_PROFILE_END();

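            // Phase 4: fold adjacent pairs with equal index using op_add; the unique
            // keys and reduced values become the COO storage of the result.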
            uint       reduced_size;
            cl::Buffer reduced_keys;
            cl::Buffer reduced_values;

            CL_PROFILE_BEGIN("reduce", queue)
            cl_reduce_by_key(queue, cl_prodi, cl_prodx, prods_count, reduced_keys, reduced_values, reduced_size, op_add, p_tmp_alloc);
            CL_PROFILE_END();

            p_cl_r->Ai     = reduced_keys;
            p_cl_r->Ax     = reduced_values;
            p_cl_r->values = reduced_size;

            p_tmp_alloc->free_all();

            return Status::Ok;
        }

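        // Build the OpenCL program specialized for the value type T and the
        // user-supplied multiply/add/select operators, acquiring it through the
        // program builder so a previously compiled program can be reused.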
        bool ensure_kernel(const ref_ptr<TOpBinary<T, T, T>>& op_multiply,
                           const ref_ptr<TOpBinary<T, T, T>>& op_add,
                           const ref_ptr<TOpSelect<T>>&       op_select,
                           std::shared_ptr<CLProgram>&        program) {
            m_block_size  = get_acc_cl()->get_default_wgs();
            m_block_count = 1;

            assert(m_block_count >= 1);

            CLProgramBuilder program_builder;
            program_builder
                    .set_name("vxm")
                    .add_define("BLOCK_SIZE", m_block_size)
                    .add_type("TYPE", get_ttype<T>().template as<Type>())
                    .add_op("OP_BINARY1", op_multiply.template as<OpBinary>())
                    .add_op("OP_BINARY2", op_add.template as<OpBinary>())
                    .add_op("OP_SELECT", op_select.template as<OpSelect>())
                    .set_source(source_vxm)
                    .acquire();

            program = program_builder.get_program();

            return true;
        }

    private:
        uint m_block_size  = 0;
        uint m_block_count = 0;
    };

}// namespace spla

#endif//SPLA_CL_VXM_HPP
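
For orientation, the kernels above implement the semantics of a masked sparse vector times CSR matrix product. The sketch below is a minimal host-side C++ reference of the same computation; it is illustrative only and not part of spla: plain float arithmetic stands in for op_multiply and op_add, a nonzero mask entry stands in for op_select, and a dense accumulator replaces the GPU count/collect/sort/reduce pipeline.

// Minimal CPU reference for the masked vxm semantics (illustrative only).
#include <cstddef>
#include <cstdint>
#include <vector>

struct CsrMatrix {
    std::vector<std::uint32_t> Ap; // row offsets, size n_rows + 1
    std::vector<std::uint32_t> Aj; // column indices of stored entries
    std::vector<float>         Ax; // values of stored entries
};

// Computes r = mask .* (v * M) for a sparse COO vector v (indices vi, values vx),
// a CSR matrix M, and a dense mask over the result entries.
void masked_vxm_reference(const std::vector<std::uint32_t>& vi,
                          const std::vector<float>&         vx,
                          const CsrMatrix&                  M,
                          const std::vector<float>&         mask,
                          std::vector<std::uint32_t>&       ri,
                          std::vector<float>&               rx) {
    std::vector<float> acc(mask.size(), 0.0f);
    std::vector<bool>  hit(mask.size(), false);

    // Mirrors the count/collect kernels: expand each v[i] against row M[i,*].
    for (std::size_t k = 0; k < vi.size(); ++k) {
        const std::uint32_t i = vi[k];
        for (std::uint32_t p = M.Ap[i]; p < M.Ap[i + 1]; ++p) {
            const std::uint32_t j = M.Aj[p];
            if (mask[j] != 0.0f) {         // stand-in for OP_SELECT
                acc[j] += vx[k] * M.Ax[p]; // stand-ins for OP_BINARY2(OP_BINARY1)
                hit[j] = true;
            }
        }
    }
    // Mirrors the sort + reduce-by-key steps: emit unique indices in order.
    for (std::uint32_t j = 0; j < acc.size(); ++j) {
        if (hit[j]) {
            ri.push_back(j);
            rx.push_back(acc[j]);
        }
    }
}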