spla
cl_mxmT_masked.hpp
Go to the documentation of this file.
1 /**********************************************************************************/
2 /* This file is part of spla project */
3 /* https://github.com/SparseLinearAlgebra/spla */
4 /**********************************************************************************/
5 /* MIT License */
6 /* */
7 /* Copyright (c) 2023 SparseLinearAlgebra */
8 /* */
9 /* Permission is hereby granted, free of charge, to any person obtaining a copy */
10 /* of this software and associated documentation files (the "Software"), to deal */
11 /* in the Software without restriction, including without limitation the rights */
12 /* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
13 /* copies of the Software, and to permit persons to whom the Software is */
14 /* furnished to do so, subject to the following conditions: */
15 /* */
16 /* The above copyright notice and this permission notice shall be included in all */
17 /* copies or substantial portions of the Software. */
18 /* */
19 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
20 /* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
21 /* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
22 /* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
23 /* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
24 /* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
25 /* SOFTWARE. */
26 /**********************************************************************************/
27 
28 #ifndef SPLA_CL_MXMT_MASKED_HPP
29 #define SPLA_CL_MXMT_MASKED_HPP
30 
32 
33 #include <core/dispatcher.hpp>
34 #include <core/registry.hpp>
35 #include <core/tmatrix.hpp>
36 #include <core/top.hpp>
37 #include <core/tscalar.hpp>
38 #include <core/ttype.hpp>
39 #include <core/tvector.hpp>
40 
41 #include <opencl/cl_counter.hpp>
42 #include <opencl/cl_debug.hpp>
43 #include <opencl/cl_formats.hpp>
46 
47 #include <algorithm>
48 #include <sstream>
49 
50 namespace spla {
51 
52  template<typename T>
53  class Algo_mxmT_masked_cl final : public RegistryAlgo {
54  public:
55  ~Algo_mxmT_masked_cl() override = default;
56 
57  std::string get_name() override {
58  return "mxmT_masked";
59  }
60 
61  std::string get_description() override {
62  return "parallel masked matrix matrix-transposed product on opencl device";
63  }
64 
65  Status execute(const DispatchContext& ctx) override {
66  TIME_PROFILE_SCOPE("opencl/mxmT_masked");
67 
68  auto t = ctx.task.template cast_safe<ScheduleTask_mxmT_masked>();
69 
70  ref_ptr<TMatrix<T>> R = t->R.template cast_safe<TMatrix<T>>();
71  ref_ptr<TMatrix<T>> mask = t->mask.template cast_safe<TMatrix<T>>();
72  ref_ptr<TMatrix<T>> A = t->A.template cast_safe<TMatrix<T>>();
73  ref_ptr<TMatrix<T>> B = t->B.template cast_safe<TMatrix<T>>();
74  ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
75  ref_ptr<TOpBinary<T, T, T>> op_add = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
76  ref_ptr<TOpSelect<T>> op_select = t->op_select.template cast_safe<TOpSelect<T>>();
77  ref_ptr<TScalar<T>> init = t->init.template cast_safe<TScalar<T>>();
78 
79  R->validate_wd(FormatMatrix::AccCsr);
80  mask->validate_rw(FormatMatrix::AccCsr);
81  A->validate_rw(FormatMatrix::AccCsr);
82  B->validate_rw(FormatMatrix::AccCsr);
83 
84  std::shared_ptr<CLProgram> program;
85  if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError;
86 
87  auto* p_cl_R = R->template get<CLCsr<T>>();
88  const auto* p_cl_mask = mask->template get<CLCsr<T>>();
89  const auto* p_cl_A = A->template get<CLCsr<T>>();
90  const auto* p_cl_B = B->template get<CLCsr<T>>();
91 
92  if (p_cl_mask->values == 0) {
93  return Status::Ok;
94  }
95  if (p_cl_A->values == 0) {
96  return Status::Ok;
97  }
98  if (p_cl_B->values == 0) {
99  return Status::Ok;
100  }
101 
102  auto* p_cl_acc = get_acc_cl();
103  auto& queue = p_cl_acc->get_queue_default();
104 
105  cl_csr_resize<T>(R->get_n_rows(), p_cl_mask->values, *p_cl_R);
106  queue.enqueueCopyBuffer(p_cl_mask->Ap, p_cl_R->Ap, 0, 0, sizeof(uint) * (R->get_n_rows() + 1));
107  queue.enqueueCopyBuffer(p_cl_mask->Aj, p_cl_R->Aj, 0, 0, sizeof(uint) * (p_cl_R->values));
108 
109  auto kernel = program->make_kernel("mxmT_masked_csr_scalar");
110  kernel.setArg(0, p_cl_A->Ap);
111  kernel.setArg(1, p_cl_A->Aj);
112  kernel.setArg(2, p_cl_A->Ax);
113  kernel.setArg(3, p_cl_B->Ap);
114  kernel.setArg(4, p_cl_B->Aj);
115  kernel.setArg(5, p_cl_B->Ax);
116  kernel.setArg(6, p_cl_mask->Ap);
117  kernel.setArg(7, p_cl_mask->Aj);
118  kernel.setArg(8, p_cl_mask->Ax);
119  kernel.setArg(9, p_cl_R->Ax);
120  kernel.setArg(10, T(init->get_value()));
121  kernel.setArg(11, R->get_n_rows());
122 
123  uint n_groups_to_dispatch = div_up_clamp(R->get_n_rows(), m_block_count, 1, 1024);
124 
125  cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size);
126  cl::NDRange exec_local(m_block_count, m_block_size);
127  CL_DISPATCH_PROFILED("exec", queue, kernel, cl::NDRange(), exec_global, exec_local);
128 
129  return Status::Ok;
130  }
131 
132  bool ensure_kernel(const ref_ptr<TOpBinary<T, T, T>>& op_multiply,
133  const ref_ptr<TOpBinary<T, T, T>>& op_add,
134  const ref_ptr<TOpSelect<T>>& op_select,
135  std::shared_ptr<CLProgram>& program) {
136  m_block_size = get_acc_cl()->get_default_wgs();
137  m_block_count = 1;
138 
139  assert(m_block_count >= 1);
140 
141  CLProgramBuilder program_builder;
142  program_builder
143  .set_name("mxmT_masked")
144  .add_define("WARP_SIZE", get_acc_cl()->get_wave_size())
145  .add_define("BLOCK_SIZE", m_block_size)
146  .add_define("BLOCK_COUNT", m_block_count)
147  .add_type("TYPE", get_ttype<T>().template as<Type>())
148  .add_op("OP_BINARY1", op_multiply.template as<OpBinary>())
149  .add_op("OP_BINARY2", op_add.template as<OpBinary>())
150  .add_op("OP_SELECT", op_select.template as<OpSelect>())
151  .set_source(source_mxmT_masked)
152  .acquire();
153 
154  program = program_builder.get_program();
155 
156  return true;
157  }
158 
159  private:
160  uint m_block_size = 0;
161  uint m_block_count = 0;
162  };
163 
164 }// namespace spla
165 
166 #endif//SPLA_CL_MXMT_MASKED_HPP
#define CL_DISPATCH_PROFILED(name, queue, kernel,...)
Definition: cl_debug.hpp:36
Status of library operation execution.
Definition: cl_mxmT_masked.hpp:53
std::string get_name() override
Definition: cl_mxmT_masked.hpp:57
Status execute(const DispatchContext &ctx) override
Definition: cl_mxmT_masked.hpp:65
bool ensure_kernel(const ref_ptr< TOpBinary< T, T, T >> &op_multiply, const ref_ptr< TOpBinary< T, T, T >> &op_add, const ref_ptr< TOpSelect< T >> &op_select, std::shared_ptr< CLProgram > &program)
Definition: cl_mxmT_masked.hpp:132
~Algo_mxmT_masked_cl() override=default
std::string get_description() override
Definition: cl_mxmT_masked.hpp:61
Runtime opencl program builder.
Definition: cl_program_builder.hpp:55
CLProgramBuilder & add_op(const char *name, const ref_ptr< OpUnary > &op)
Definition: cl_program_builder.cpp:49
const std::shared_ptr< CLProgram > & get_program()
Definition: cl_program_builder.hpp:66
CLProgramBuilder & add_define(const char *define, int value)
Definition: cl_program_builder.cpp:41
CLProgramBuilder & set_name(const char *name)
Definition: cl_program_builder.cpp:37
CLProgramBuilder & add_type(const char *alias, const ref_ptr< Type > &type)
Definition: cl_program_builder.cpp:45
CLProgramBuilder & set_source(const char *source)
Definition: cl_program_builder.cpp:61
void acquire()
Definition: cl_program_builder.cpp:65
Algorithm suitable to process schedule task based on task string key.
Definition: registry.hpp:66
Definition: top.hpp:174
Definition: top.hpp:228
Automates reference counting and behaves as shared smart pointer.
Definition: ref.hpp:117
std::uint32_t uint
Library index and size type.
Definition: config.hpp:56
Definition: algorithm.hpp:37
Execution context of a single task.
Definition: dispatcher.hpp:46
ref_ptr< ScheduleTask > task
Definition: dispatcher.hpp:48
#define TIME_PROFILE_SCOPE(name)
Definition: time_profiler.hpp:92