spla
Loading...
Searching...
No Matches
tmatrix.hpp
Go to the documentation of this file.
1/**********************************************************************************/
2/* This file is part of spla project */
3/* https://github.com/SparseLinearAlgebra/spla */
4/**********************************************************************************/
5/* MIT License */
6/* */
7/* Copyright (c) 2023 SparseLinearAlgebra */
8/* */
9/* Permission is hereby granted, free of charge, to any person obtaining a copy */
10/* of this software and associated documentation files (the "Software"), to deal */
11/* in the Software without restriction, including without limitation the rights */
12/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
13/* copies of the Software, and to permit persons to whom the Software is */
14/* furnished to do so, subject to the following conditions: */
15/* */
16/* The above copyright notice and this permission notice shall be included in all */
17/* copies or substantial portions of the Software. */
18/* */
19/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
20/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
21/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
22/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
23/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
24/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
25/* SOFTWARE. */
26/**********************************************************************************/
27
28#ifndef SPLA_TMATRIX_HPP
29#define SPLA_TMATRIX_HPP
30
31#include <spla/config.hpp>
32#include <spla/matrix.hpp>
33
34#include <core/logger.hpp>
35#include <core/tarray.hpp>
36#include <core/tdecoration.hpp>
37#include <core/top.hpp>
38#include <core/ttype.hpp>
39
42
43namespace spla {
44
56 template<typename T>
57 class TMatrix final : public Matrix {
58 public:
59 TMatrix(uint n_rows, uint n_cols);
60 ~TMatrix() override = default;
61
62 uint get_n_rows() override;
63 uint get_n_cols() override;
64 ref_ptr<Type> get_type() override;
65 void set_label(std::string label) override;
66 const std::string& get_label() const override;
67 Status set_format(FormatMatrix format) override;
68 Status set_fill_value(const ref_ptr<Scalar>& value) override;
69 Status set_reduce(ref_ptr<OpBinary> resolve_duplicates) override;
70 Status set_int(uint row_id, uint col_id, std::int32_t value) override;
71 Status set_uint(uint row_id, uint col_id, std::uint32_t value) override;
72 Status set_float(uint row_id, uint col_id, float value) override;
73 Status get_int(uint row_id, uint col_id, int32_t& value) override;
74 Status get_uint(uint row_id, uint col_id, uint32_t& value) override;
75 Status get_float(uint row_id, uint col_id, float& value) override;
76 Status build(const ref_ptr<MemView>& keys1, const ref_ptr<MemView>& keys2, const ref_ptr<MemView>& values) override;
77 Status read(ref_ptr<MemView>& keys1, ref_ptr<MemView>& keys2, ref_ptr<MemView>& values) override;
78 Status clear() override;
79
80 template<typename Decorator>
81 Decorator* get() { return m_storage.template get<Decorator>(); }
82
83 void validate_rw(FormatMatrix format);
84 void validate_rwd(FormatMatrix format);
85 void validate_wd(FormatMatrix format);
86 void validate_ctor(FormatMatrix format);
87 bool is_valid(FormatMatrix format) const;
88 T get_fill_value() const { return m_storage.get_fill_value(); }
89
91
92 private:
93 typename StorageManagerMatrix<T>::Storage m_storage;
94 std::string m_label;
95 };
96
97 template<typename T>
98 TMatrix<T>::TMatrix(uint n_rows, uint n_cols) {
99 m_storage.set_dims(n_rows, n_cols);
100 }
101
102 template<typename T>
104 return m_storage.get_n_rows();
105 }
106 template<typename T>
108 return m_storage.get_n_cols();
109 }
110 template<typename T>
112 return get_ttype<T>().template as<Type>();
113 }
114
115 template<typename T>
116 void TMatrix<T>::set_label(std::string label) {
117 m_label = std::move(label);
118 LOG_MSG(Status::Ok, "set label '" << m_label << "' to " << (void*) this);
119 }
120 template<typename T>
121 const std::string& TMatrix<T>::get_label() const {
122 return m_label;
123 }
124
125 template<typename T>
127 validate_rw(format);
128 return Status::Ok;
129 }
130 template<typename T>
132 if (value) {
133 m_storage.invalidate();
134
135 if constexpr (std::is_same<T, T_INT>::value) m_storage.set_fill_value(value->as_int());
136 if constexpr (std::is_same<T, T_UINT>::value) m_storage.set_fill_value(value->as_uint());
137 if constexpr (std::is_same<T, T_FLOAT>::value) m_storage.set_fill_value(value->as_float());
138
139 return Status::Ok;
140 }
141
143 }
144 template<typename T>
146 auto reduce = resolve_duplicates.template cast_safe<TOpBinary<T, T, T>>();
147
148 if (reduce) {
149 validate_ctor(FormatMatrix::CpuLil);
150 get<CpuLil<T>>()->reduce = reduce->function;
151 validate_ctor(FormatMatrix::CpuDok);
152 get<CpuDok<T>>()->reduce = reduce->function;
153 }
154
156 }
157
158 template<typename T>
159 Status TMatrix<T>::set_int(uint row_id, uint col_id, std::int32_t value) {
160 validate_rwd(FormatMatrix::CpuLil);
161 cpu_lil_add_element(row_id, col_id, static_cast<T>(value), *get<CpuLil<T>>());
162 return Status::Ok;
163 }
164 template<typename T>
165 Status TMatrix<T>::set_uint(uint row_id, uint col_id, std::uint32_t value) {
166 validate_rwd(FormatMatrix::CpuLil);
167 cpu_lil_add_element(row_id, col_id, static_cast<T>(value), *get<CpuLil<T>>());
168 return Status::Ok;
169 }
170 template<typename T>
171 Status TMatrix<T>::set_float(uint row_id, uint col_id, float value) {
172 validate_rwd(FormatMatrix::CpuLil);
173 cpu_lil_add_element(row_id, col_id, static_cast<T>(value), *get<CpuLil<T>>());
174 return Status::Ok;
175 }
176
177 template<typename T>
178 Status TMatrix<T>::get_int(uint row_id, uint col_id, int32_t& value) {
179 validate_rw(FormatMatrix::CpuDok);
180
181 auto& Ax = get<CpuDok<T>>()->Ax;
182 auto entry = Ax.find(typename CpuDok<T>::Key(row_id, col_id));
183 value = m_storage.get_fill_value();
184
185 if (entry != Ax.end()) {
186 value = static_cast<int32_t>(entry->second);
187 }
188
189 return Status::Ok;
190 }
191 template<typename T>
192 Status TMatrix<T>::get_uint(uint row_id, uint col_id, uint32_t& value) {
193 validate_rw(FormatMatrix::CpuDok);
194
195 auto& Ax = get<CpuDok<T>>()->Ax;
196 auto entry = Ax.find(typename CpuDok<T>::Key(row_id, col_id));
197 value = m_storage.get_fill_value();
198
199 if (entry != Ax.end()) {
200 value = static_cast<uint32_t>(entry->second);
201 }
202
203 return Status::Ok;
204 }
205 template<typename T>
206 Status TMatrix<T>::get_float(uint row_id, uint col_id, float& value) {
207 validate_rw(FormatMatrix::CpuDok);
208
209 auto& Ax = get<CpuDok<T>>()->Ax;
210 auto entry = Ax.find(typename CpuDok<T>::Key(row_id, col_id));
211 value = m_storage.get_fill_value();
212
213 if (entry != Ax.end()) {
214 value = static_cast<float>(entry->second);
215 }
216
217 return Status::Ok;
218 }
219
220 template<typename T>
221 Status TMatrix<T>::build(const ref_ptr<MemView>& keys1, const ref_ptr<MemView>& keys2, const ref_ptr<MemView>& values) {
222 assert(keys1);
223 assert(keys2);
224 assert(values);
225
226 const auto key_size = sizeof(uint);
227 const auto value_size = sizeof(T);
228 const auto elements_count = keys1->get_size() / key_size;
229
230 if (elements_count != values->get_size() / value_size) {
232 }
233 if (elements_count * key_size != keys1->get_size()) {
235 }
236 if (elements_count * key_size != keys2->get_size()) {
238 }
239
240 validate_rwd(FormatMatrix::CpuCoo);
241 CpuCoo<T>& coo = *get<CpuCoo<T>>();
242
243 coo.Ai.resize(elements_count);
244 coo.Aj.resize(elements_count);
245 coo.Ax.resize(elements_count);
246 coo.values = uint(elements_count);
247
248 keys1->read(0, key_size * elements_count, coo.Ai.data());
249 keys2->read(0, key_size * elements_count, coo.Aj.data());
250 values->read(0, value_size * elements_count, coo.Ax.data());
251
252 return Status::Ok;
253 }
254 template<typename T>
256 const auto key_size = sizeof(uint);
257 const auto value_size = sizeof(T);
258
259 validate_rw(FormatMatrix::CpuCoo);
260 CpuCoo<T>& coo = *get<CpuCoo<T>>();
261
262 const auto elements_count = coo.Ai.size();
263
264 keys1 = MemView::make(coo.Ai.data(), key_size * elements_count, false);
265 keys2 = MemView::make(coo.Aj.data(), key_size * elements_count, false);
266 values = MemView::make(coo.Ax.data(), value_size * elements_count, false);
267
268 return Status::Ok;
269 }
270
271 template<typename T>
273 m_storage.invalidate();
274 return Status::Ok;
275 }
276
277 template<typename T>
279 StorageManagerMatrix<T>* manager = get_storage_manager();
280 manager->validate_rw(format, m_storage);
281 }
282
283 template<typename T>
285 StorageManagerMatrix<T>* manager = get_storage_manager();
286 manager->validate_rwd(format, m_storage);
287 }
288
289 template<typename T>
291 StorageManagerMatrix<T>* manager = get_storage_manager();
292 manager->validate_wd(format, m_storage);
293 }
294
295 template<typename T>
297 StorageManagerMatrix<T>* manager = get_storage_manager();
298 manager->validate_ctor(format, m_storage);
299 }
300
301 template<typename T>
303 return m_storage.is_valid(format);
304 }
305
306 template<typename T>
308 static std::unique_ptr<StorageManagerMatrix<T>> storage_manager;
309
310 if (!storage_manager) {
311 storage_manager = std::make_unique<StorageManagerMatrix<T>>();
312 register_formats_matrix(*storage_manager);
313 }
314
315 return storage_manager.get();
316 }
317
322}// namespace spla
323
324
325#endif//SPLA_TMATRIX_HPP
Named storage format for library matrix data objects.
Status of library operation execution.
CPU list of coordinates matrix format.
Definition cpu_formats.hpp:148
std::vector< uint > Aj
Definition cpu_formats.hpp:155
std::vector< T > Ax
Definition cpu_formats.hpp:156
std::vector< uint > Ai
Definition cpu_formats.hpp:154
std::pair< uint, uint > Key
Definition cpu_formats.hpp:134
CPU list-of-list matrix format for fast incremental build.
Definition cpu_formats.hpp:107
Generalized M x N dimensional matrix object.
Definition matrix.hpp:48
static ref_ptr< MemView > make()
Definition memview.cpp:67
General format converter for vector or matrix decoration storage.
Definition storage_manager.hpp:57
Storage for decorators with data of a particular vector or matrix object.
Definition tdecoration.hpp:70
void set_dims(uint n_rows, uint n_cols)
Definition tdecoration.hpp:92
uint values
Definition tdecoration.hpp:58
Matrix interface implementation with type information bound.
Definition tmatrix.hpp:57
~TMatrix() override=default
T get_fill_value() const
Definition tmatrix.hpp:88
Decorator * get()
Definition tmatrix.hpp:81
Automates reference counting and behaves as a shared smart pointer.
Definition ref.hpp:117
ref_ptr< Type > get_type() override
Definition tmatrix.hpp:111
TMatrix(uint n_rows, uint n_cols)
Definition tmatrix.hpp:98
void validate_wd(FormatMatrix format)
Definition tmatrix.hpp:290
void validate_rwd(FormatMatrix format)
Definition tmatrix.hpp:284
bool is_valid(FormatMatrix format) const
Definition tmatrix.hpp:302
void validate_wd(F format, Storage &storage)
Definition storage_manager.hpp:212
static StorageManagerMatrix< T > * get_storage_manager()
Definition tmatrix.hpp:307
void validate_rw(F format, Storage &storage)
Definition storage_manager.hpp:128
Status get_int(uint row_id, uint col_id, int32_t &value) override
Definition tmatrix.hpp:178
const std::string & get_label() const override
Definition tmatrix.hpp:121
void validate_rw(FormatMatrix format)
Definition tmatrix.hpp:278
Status set_fill_value(const ref_ptr< Scalar > &value) override
Definition tmatrix.hpp:131
uint get_n_cols() override
Definition tmatrix.hpp:107
Status get_uint(uint row_id, uint col_id, uint32_t &value) override
Definition tmatrix.hpp:192
Status clear() override
Definition tmatrix.hpp:272
Status set_int(uint row_id, uint col_id, std::int32_t value) override
Definition tmatrix.hpp:159
Status set_uint(uint row_id, uint col_id, std::uint32_t value) override
Definition tmatrix.hpp:165
Status get_float(uint row_id, uint col_id, float &value) override
Definition tmatrix.hpp:206
Status read(ref_ptr< MemView > &keys1, ref_ptr< MemView > &keys2, ref_ptr< MemView > &values) override
Definition tmatrix.hpp:255
void set_label(std::string label) override
Definition tmatrix.hpp:116
void cpu_lil_add_element(uint row_id, uint col_id, T element, CpuLil< T > &lil)
Definition cpu_format_lil.hpp:55
void validate_ctor(F format, Storage &storage)
Definition storage_manager.hpp:121
void validate_rwd(F format, Storage &storage)
Definition storage_manager.hpp:206
Status set_float(uint row_id, uint col_id, float value) override
Definition tmatrix.hpp:171
Status set_format(FormatMatrix format) override
Definition tmatrix.hpp:126
uint get_n_rows() override
Definition tmatrix.hpp:103
void validate_ctor(FormatMatrix format)
Definition tmatrix.hpp:296
Status build(const ref_ptr< MemView > &keys1, const ref_ptr< MemView > &keys2, const ref_ptr< MemView > &values) override
Definition tmatrix.hpp:221
Status set_reduce(ref_ptr< OpBinary > resolve_duplicates) override
Definition tmatrix.hpp:145
std::uint32_t uint
Library index and size type.
Definition config.hpp:56
#define LOG_MSG(status, msg)
Definition logger.hpp:66
Definition algorithm.hpp:37
void register_formats_matrix(StorageManagerMatrix< T > &manager)
Definition storage_manager_matrix.hpp:51