spla
Loading...
Searching...
No Matches
cl_mxmT_masked.hpp
Go to the documentation of this file.
1/**********************************************************************************/
2/* This file is part of spla project */
3/* https://github.com/SparseLinearAlgebra/spla */
4/**********************************************************************************/
5/* MIT License */
6/* */
7/* Copyright (c) 2023 SparseLinearAlgebra */
8/* */
9/* Permission is hereby granted, free of charge, to any person obtaining a copy */
10/* of this software and associated documentation files (the "Software"), to deal */
11/* in the Software without restriction, including without limitation the rights */
12/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
13/* copies of the Software, and to permit persons to whom the Software is */
14/* furnished to do so, subject to the following conditions: */
15/* */
16/* The above copyright notice and this permission notice shall be included in all */
17/* copies or substantial portions of the Software. */
18/* */
19/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
20/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
21/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
22/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
23/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
24/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
25/* SOFTWARE. */
26/**********************************************************************************/
27
28#ifndef SPLA_CL_MXMT_MASKED_HPP
29#define SPLA_CL_MXMT_MASKED_HPP
30
32
33#include <core/dispatcher.hpp>
34#include <core/registry.hpp>
35#include <core/tmatrix.hpp>
36#include <core/top.hpp>
37#include <core/tscalar.hpp>
38#include <core/ttype.hpp>
39#include <core/tvector.hpp>
40
41#include <opencl/cl_counter.hpp>
42#include <opencl/cl_debug.hpp>
43#include <opencl/cl_formats.hpp>
46
47#include <algorithm>
48#include <sstream>
49
50namespace spla {
51
52 template<typename T>
53 class Algo_mxmT_masked_cl final : public RegistryAlgo {
54 public:
55 ~Algo_mxmT_masked_cl() override = default;
56
57 std::string get_name() override {
58 return "mxmT_masked";
59 }
60
61 std::string get_description() override {
62 return "parallel masked matrix matrix-transposed product on opencl device";
63 }
64
65 Status execute(const DispatchContext& ctx) override {
66 TIME_PROFILE_SCOPE("opencl/mxmT_masked");
67
68 auto t = ctx.task.template cast_safe<ScheduleTask_mxmT_masked>();
69
70 ref_ptr<TMatrix<T>> R = t->R.template cast_safe<TMatrix<T>>();
71 ref_ptr<TMatrix<T>> mask = t->mask.template cast_safe<TMatrix<T>>();
72 ref_ptr<TMatrix<T>> A = t->A.template cast_safe<TMatrix<T>>();
73 ref_ptr<TMatrix<T>> B = t->B.template cast_safe<TMatrix<T>>();
74 ref_ptr<TOpBinary<T, T, T>> op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
75 ref_ptr<TOpBinary<T, T, T>> op_add = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
76 ref_ptr<TOpSelect<T>> op_select = t->op_select.template cast_safe<TOpSelect<T>>();
77 ref_ptr<TScalar<T>> init = t->init.template cast_safe<TScalar<T>>();
78
79 R->validate_wd(FormatMatrix::AccCsr);
80 mask->validate_rw(FormatMatrix::AccCsr);
81 A->validate_rw(FormatMatrix::AccCsr);
82 B->validate_rw(FormatMatrix::AccCsr);
83
84 std::shared_ptr<CLProgram> program;
85 if (!ensure_kernel(op_multiply, op_add, op_select, program)) return Status::CompilationError;
86
87 auto* p_cl_R = R->template get<CLCsr<T>>();
88 const auto* p_cl_mask = mask->template get<CLCsr<T>>();
89 const auto* p_cl_A = A->template get<CLCsr<T>>();
90 const auto* p_cl_B = B->template get<CLCsr<T>>();
91
92 if (p_cl_mask->values == 0) {
93 return Status::Ok;
94 }
95 if (p_cl_A->values == 0) {
96 return Status::Ok;
97 }
98 if (p_cl_B->values == 0) {
99 return Status::Ok;
100 }
101
102 auto* p_cl_acc = get_acc_cl();
103 auto& queue = p_cl_acc->get_queue_default();
104
105 cl_csr_resize<T>(R->get_n_rows(), p_cl_mask->values, *p_cl_R);
106 queue.enqueueCopyBuffer(p_cl_mask->Ap, p_cl_R->Ap, 0, 0, sizeof(uint) * (R->get_n_rows() + 1));
107 queue.enqueueCopyBuffer(p_cl_mask->Aj, p_cl_R->Aj, 0, 0, sizeof(uint) * (p_cl_R->values));
108
109 auto kernel = program->make_kernel("mxmT_masked_csr_scalar");
110 kernel.setArg(0, p_cl_A->Ap);
111 kernel.setArg(1, p_cl_A->Aj);
112 kernel.setArg(2, p_cl_A->Ax);
113 kernel.setArg(3, p_cl_B->Ap);
114 kernel.setArg(4, p_cl_B->Aj);
115 kernel.setArg(5, p_cl_B->Ax);
116 kernel.setArg(6, p_cl_mask->Ap);
117 kernel.setArg(7, p_cl_mask->Aj);
118 kernel.setArg(8, p_cl_mask->Ax);
119 kernel.setArg(9, p_cl_R->Ax);
120 kernel.setArg(10, T(init->get_value()));
121 kernel.setArg(11, R->get_n_rows());
122
123 uint n_groups_to_dispatch = div_up_clamp(R->get_n_rows(), m_block_count, 1, 1024);
124
125 cl::NDRange exec_global(m_block_count * n_groups_to_dispatch, m_block_size);
126 cl::NDRange exec_local(m_block_count, m_block_size);
127 CL_DISPATCH_PROFILED("exec", queue, kernel, cl::NDRange(), exec_global, exec_local);
128
129 return Status::Ok;
130 }
131
132 bool ensure_kernel(const ref_ptr<TOpBinary<T, T, T>>& op_multiply,
133 const ref_ptr<TOpBinary<T, T, T>>& op_add,
134 const ref_ptr<TOpSelect<T>>& op_select,
135 std::shared_ptr<CLProgram>& program) {
136 m_block_size = get_acc_cl()->get_default_wgs();
137 m_block_count = 1;
138
139 assert(m_block_count >= 1);
140
141 CLProgramBuilder program_builder;
142 program_builder
143 .set_name("mxmT_masked")
144 .add_define("WARP_SIZE", get_acc_cl()->get_wave_size())
145 .add_define("BLOCK_SIZE", m_block_size)
146 .add_define("BLOCK_COUNT", m_block_count)
147 .add_type("TYPE", get_ttype<T>().template as<Type>())
148 .add_op("OP_BINARY1", op_multiply.template as<OpBinary>())
149 .add_op("OP_BINARY2", op_add.template as<OpBinary>())
150 .add_op("OP_SELECT", op_select.template as<OpSelect>())
151 .set_source(source_mxmT_masked)
152 .acquire();
153
154 program = program_builder.get_program();
155
156 return true;
157 }
158
159 private:
160 uint m_block_size = 0;
161 uint m_block_count = 0;
162 };
163
164}// namespace spla
165
166#endif//SPLA_CL_MXMT_MASKED_HPP
#define CL_DISPATCH_PROFILED(name, queue, kernel,...)
Definition cl_debug.hpp:36
Status of library operation execution.
Definition cl_mxmT_masked.hpp:53
std::string get_name() override
Definition cl_mxmT_masked.hpp:57
Status execute(const DispatchContext &ctx) override
Definition cl_mxmT_masked.hpp:65
~Algo_mxmT_masked_cl() override=default
bool ensure_kernel(const ref_ptr< TOpBinary< T, T, T > > &op_multiply, const ref_ptr< TOpBinary< T, T, T > > &op_add, const ref_ptr< TOpSelect< T > > &op_select, std::shared_ptr< CLProgram > &program)
Definition cl_mxmT_masked.hpp:132
std::string get_description() override
Definition cl_mxmT_masked.hpp:61
Runtime opencl program builder.
Definition cl_program_builder.hpp:55
CLProgramBuilder & add_op(const char *name, const ref_ptr< OpUnary > &op)
Definition cl_program_builder.cpp:49
CLProgramBuilder & add_define(const char *define, int value)
Definition cl_program_builder.cpp:41
const std::shared_ptr< CLProgram > & get_program()
Definition cl_program_builder.hpp:66
CLProgramBuilder & set_name(const char *name)
Definition cl_program_builder.cpp:37
CLProgramBuilder & add_type(const char *alias, const ref_ptr< Type > &type)
Definition cl_program_builder.cpp:45
CLProgramBuilder & set_source(const char *source)
Definition cl_program_builder.cpp:61
void acquire()
Definition cl_program_builder.cpp:65
Algorithm suitable to process schedule task based on task string key.
Definition registry.hpp:66
Definition top.hpp:174
Definition top.hpp:228
Automates reference counting and behaves as shared smart pointer.
Definition ref.hpp:117
void cl_csr_resize(std::size_t n_rows, std::size_t n_values, CLCsr< T > &storage)
Definition cl_format_csr.hpp:62
std::uint32_t uint
Library index and size type.
Definition config.hpp:56
Definition algorithm.hpp:37
Execution context of a single task.
Definition dispatcher.hpp:46
ref_ptr< ScheduleTask > task
Definition dispatcher.hpp:48
#define TIME_PROFILE_SCOPE(name)
Definition time_profiler.hpp:92