spla
Loading...
Searching...
No Matches
cl_vxm.hpp
Go to the documentation of this file.
1/**********************************************************************************/
2/* This file is part of spla project */
3/* https://github.com/SparseLinearAlgebra/spla */
4/**********************************************************************************/
5/* MIT License */
6/* */
7/* Copyright (c) 2023 SparseLinearAlgebra */
8/* */
9/* Permission is hereby granted, free of charge, to any person obtaining a copy */
10/* of this software and associated documentation files (the "Software"), to deal */
11/* in the Software without restriction, including without limitation the rights */
12/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
13/* copies of the Software, and to permit persons to whom the Software is */
14/* furnished to do so, subject to the following conditions: */
15/* */
16/* The above copyright notice and this permission notice shall be included in all */
17/* copies or substantial portions of the Software. */
18/* */
19/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
20/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
21/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
22/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
23/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
24/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
25/* SOFTWARE. */
26/**********************************************************************************/
27
28#ifndef SPLA_CL_VXM_HPP
29#define SPLA_CL_VXM_HPP
30
32
33#include <core/dispatcher.hpp>
34#include <core/registry.hpp>
35#include <core/tmatrix.hpp>
36#include <core/top.hpp>
37#include <core/tscalar.hpp>
38#include <core/ttype.hpp>
39#include <core/tvector.hpp>
40
42#include <opencl/cl_counter.hpp>
43#include <opencl/cl_debug.hpp>
44#include <opencl/cl_formats.hpp>
49
50#include <algorithm>
51#include <sstream>
52
53namespace spla {
54
55 template<typename T>
56 class Algo_vxm_masked_cl final : public RegistryAlgo {
57 public:
58 ~Algo_vxm_masked_cl() override = default;
59
60 std::string get_name() override {
61 return "vxm_masked";
62 }
63
64 std::string get_description() override {
65 return "parallel vector-matrix masked product on opencl device";
66 }
67
68 Status execute(const DispatchContext& ctx) override {
69 return execute_sparse(ctx);
70 }
71
72 private:
73 Status execute_sparse(const DispatchContext& ctx) {
74 TIME_PROFILE_SCOPE("opencl/vxm/sparse");
75
76 auto t = ctx.task.template cast_safe<ScheduleTask_vxm_masked>();
77
78 ref_ptr<TVector<T>> r = t->r.template cast_safe<TVector<T>>();
79 ref_ptr<TVector<T>> mask = t->mask.template cast_safe<TVector<T>>();
80 ref_ptr<TVector<T>> v = t->v.template cast_safe<TVector<T>>();
81 ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
82 ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
83 ref_ptr<TOpBinary<T, T, T>> op_add = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
84 ref_ptr<TOpSelect<T>> op_select = t->op_select.template cast_safe<TOpSelect<T>>();
85 ref_ptr<TScalar<T>> init = t->init.template cast_safe<TScalar<T>>();
86
87 r->validate_wd(FormatVector::AccCoo);
88 mask->validate_rw(FormatVector::AccDense);
89 M->validate_rw(FormatMatrix::AccCsr);
90 v->validate_rw(FormatVector::AccCoo);
91 std::shared_ptr<CLProgram> program;
92 if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError;
93
94 auto* p_cl_r = r->template get<CLCooVec<T>>();
95 auto* p_cl_mask = mask->template get<CLDenseVec<T>>();
96 auto* p_cl_M = M->template get<CLCsr<T>>();
97 auto* p_cl_v = v->template get<CLCooVec<T>>();
98
99 auto* p_cl_acc = get_acc_cl();
100 auto* p_tmp_alloc = p_cl_acc->get_alloc_tmp();
101 auto& queue = p_cl_acc->get_queue_default();
102
103 uint prods_count;
104 CLCounterWrapper cl_prods_count;
105 CL_COUNTER_SET("init-prods-cnt", queue, cl_prods_count, 0);
106
107 auto kernel_sparse_count = program->make_kernel("vxm_sparse_count");
108 kernel_sparse_count.setArg(0, p_cl_v->Ai);
109 kernel_sparse_count.setArg(1, p_cl_v->Ax);
110 kernel_sparse_count.setArg(2, p_cl_M->Ap);
111 kernel_sparse_count.setArg(3, p_cl_M->Aj);
112 kernel_sparse_count.setArg(4, p_cl_mask->Ax);
113 kernel_sparse_count.setArg(5, cl_prods_count.buffer());
114 kernel_sparse_count.setArg(6, p_cl_v->values);
115
116 uint n_groups_to_dispatch_v = div_up_clamp(p_cl_v->values, m_block_size, 1, 1024);
117
118 cl::NDRange count_global(m_block_size * n_groups_to_dispatch_v);
119 cl::NDRange count_local(m_block_size);
120 CL_DISPATCH_PROFILED("count", queue, kernel_sparse_count, cl::NDRange(), count_global, count_local);
121
122 CL_COUNTER_GET("copy_prods_count", queue, cl_prods_count, prods_count);
123 LOG_MSG(Status::Ok, "temporary vi * A[,*] count " << prods_count);
124
125 if (prods_count == 0) {
126 LOG_MSG(Status::Ok, "nothing to do");
127
128 p_cl_r->Ai = cl::Buffer();
129 p_cl_r->Ax = cl::Buffer();
130 p_cl_r->values = 0;
131
132 return Status::Ok;
133 }
134
135 cl::Buffer cl_prodi = p_tmp_alloc->alloc(prods_count * sizeof(uint));
136 cl::Buffer cl_prodx = p_tmp_alloc->alloc(prods_count * sizeof(T));
137
138 CLCounterWrapper cl_prods_offset;
139 CL_COUNTER_SET("init-offsets-cnt", queue, cl_prods_offset, 0);
140
141 auto kernel_sparse_collect = program->make_kernel("vxm_sparse_collect");
142 kernel_sparse_collect.setArg(0, p_cl_v->Ai);
143 kernel_sparse_collect.setArg(1, p_cl_v->Ax);
144 kernel_sparse_collect.setArg(2, p_cl_M->Ap);
145 kernel_sparse_collect.setArg(3, p_cl_M->Aj);
146 kernel_sparse_collect.setArg(4, p_cl_M->Ax);
147 kernel_sparse_collect.setArg(5, p_cl_mask->Ax);
148 kernel_sparse_collect.setArg(6, cl_prodi);
149 kernel_sparse_collect.setArg(7, cl_prodx);
150 kernel_sparse_collect.setArg(8, cl_prods_offset.buffer());
151 kernel_sparse_collect.setArg(9, p_cl_v->values);
152
153 cl::NDRange collect_global(m_block_size * n_groups_to_dispatch_v);
154 cl::NDRange collect_local(m_block_size);
155 CL_DISPATCH_PROFILED("collect", queue, kernel_sparse_collect, cl::NDRange(), collect_global, collect_local);
156
157 CL_PROFILE_BEGIN("sort", queue)
158 const uint max_key = r->get_n_rows() - 1;
159 const uint n_elements = prods_count;
160 cl_sort_by_key<T>(queue, cl_prodi, cl_prodx, n_elements, p_tmp_alloc, max_key);
162
163 uint reduced_size;
164 cl::Buffer reduced_keys;
165 cl::Buffer reduced_values;
166
167 CL_PROFILE_BEGIN("reduce", queue)
168 cl_reduce_by_key(queue, cl_prodi, cl_prodx, prods_count, reduced_keys, reduced_values, reduced_size, op_add, p_tmp_alloc);
170
171 p_cl_r->Ai = reduced_keys;
172 p_cl_r->Ax = reduced_values;
173 p_cl_r->values = reduced_size;
174
175 p_tmp_alloc->free_all();
176
177 return Status::Ok;
178 }
179
180 bool ensure_kernel(const ref_ptr<TOpBinary<T, T, T>>& op_multiply,
181 const ref_ptr<TOpBinary<T, T, T>>& op_add,
182 const ref_ptr<TOpSelect<T>>& op_select,
183 std::shared_ptr<CLProgram>& program) {
184 m_block_size = get_acc_cl()->get_default_wgs();
185 m_block_count = 1;
186
187 assert(m_block_count >= 1);
188
189 CLProgramBuilder program_builder;
190 program_builder
191 .set_name("vxm")
192 .add_define("BLOCK_SIZE", m_block_size)
193 .add_type("TYPE", get_ttype<T>().template as<Type>())
194 .add_op("OP_BINARY1", op_multiply.template as<OpBinary>())
195 .add_op("OP_BINARY2", op_add.template as<OpBinary>())
196 .add_op("OP_SELECT", op_select.template as<OpSelect>())
197 .set_source(source_vxm)
198 .acquire();
199
200 program = program_builder.get_program();
201
202 return true;
203 }
204
205 private:
206 uint m_block_size = 0;
207 uint m_block_count = 0;
208 };
209
210}// namespace spla
211
212#endif//SPLA_CL_VXM_HPP
#define CL_PROFILE_END()
Definition cl_debug.hpp:111
#define CL_COUNTER_SET(name, queue, counter, value)
Definition cl_debug.hpp:87
#define CL_COUNTER_GET(name, queue, counter, value)
Definition cl_debug.hpp:70
#define CL_DISPATCH_PROFILED(name, queue, kernel,...)
Definition cl_debug.hpp:36
#define CL_PROFILE_BEGIN(name, queue)
Definition cl_debug.hpp:104
Status of library operation execution.
Definition cl_vxm.hpp:56
std::string get_name() override
Definition cl_vxm.hpp:60
~Algo_vxm_masked_cl() override=default
Status execute(const DispatchContext &ctx) override
Definition cl_vxm.hpp:68
std::string get_description() override
Definition cl_vxm.hpp:64
Definition cl_counter.hpp:58
cl::Buffer & buffer()
Definition cl_counter.cpp:59
Algorithm suitable to process schedule task based on task string key.
Definition registry.hpp:66
Automates reference counting and behaves as shared smart pointer.
Definition ref.hpp:117
std::uint32_t uint
Library index and size type.
Definition config.hpp:56
#define LOG_MSG(status, msg)
Definition logger.hpp:66
Definition algorithm.hpp:37
void cl_sort_by_key(cl::CommandQueue &queue, cl::Buffer &keys, cl::Buffer &values, uint n, CLAlloc *tmp_alloc, uint max_key=0xffffffff)
Definition cl_sort_by_key.hpp:175
void cl_reduce_by_key(cl::CommandQueue &queue, const cl::Buffer &keys, const cl::Buffer &values, const uint size, cl::Buffer &unique_keys, cl::Buffer &reduce_values, uint &reduced_size, const ref_ptr< TOpBinary< T, T, T > > &reduce_op, CLAlloc *tmp_alloc)
Definition cl_reduce_by_key.hpp:43
Execution context of a single task.
Definition dispatcher.hpp:46
ref_ptr< ScheduleTask > task
Definition dispatcher.hpp:48
#define TIME_PROFILE_SCOPE(name)
Definition time_profiler.hpp:92