spla
tmatrix.hpp
Go to the documentation of this file.
1 /**********************************************************************************/
2 /* This file is part of spla project */
3 /* https://github.com/SparseLinearAlgebra/spla */
4 /**********************************************************************************/
5 /* MIT License */
6 /* */
7 /* Copyright (c) 2023 SparseLinearAlgebra */
8 /* */
9 /* Permission is hereby granted, free of charge, to any person obtaining a copy */
10 /* of this software and associated documentation files (the "Software"), to deal */
11 /* in the Software without restriction, including without limitation the rights */
12 /* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */
13 /* copies of the Software, and to permit persons to whom the Software is */
14 /* furnished to do so, subject to the following conditions: */
15 /* */
16 /* The above copyright notice and this permission notice shall be included in all */
17 /* copies or substantial portions of the Software. */
18 /* */
19 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */
20 /* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */
21 /* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */
22 /* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */
23 /* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */
24 /* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */
25 /* SOFTWARE. */
26 /**********************************************************************************/
27 
28 #ifndef SPLA_TMATRIX_HPP
29 #define SPLA_TMATRIX_HPP
30 
31 #include <spla/config.hpp>
32 #include <spla/matrix.hpp>
33 
34 #include <core/logger.hpp>
35 #include <core/tarray.hpp>
36 #include <core/tdecoration.hpp>
37 #include <core/top.hpp>
38 #include <core/ttype.hpp>
39 
42 
43 namespace spla {
44 
56  template<typename T>
57  class TMatrix final : public Matrix {
58  public:
59  TMatrix(uint n_rows, uint n_cols);
60  ~TMatrix() override = default;
61 
62  uint get_n_rows() override;
63  uint get_n_cols() override;
64  ref_ptr<Type> get_type() override;
65  void set_label(std::string label) override;
66  const std::string& get_label() const override;
67  Status set_format(FormatMatrix format) override;
68  Status set_fill_value(const ref_ptr<Scalar>& value) override;
69  Status set_reduce(ref_ptr<OpBinary> resolve_duplicates) override;
70  Status set_int(uint row_id, uint col_id, std::int32_t value) override;
71  Status set_uint(uint row_id, uint col_id, std::uint32_t value) override;
72  Status set_float(uint row_id, uint col_id, float value) override;
73  Status get_int(uint row_id, uint col_id, int32_t& value) override;
74  Status get_uint(uint row_id, uint col_id, uint32_t& value) override;
75  Status get_float(uint row_id, uint col_id, float& value) override;
76  Status build(const ref_ptr<MemView>& keys1, const ref_ptr<MemView>& keys2, const ref_ptr<MemView>& values) override;
77  Status read(ref_ptr<MemView>& keys1, ref_ptr<MemView>& keys2, ref_ptr<MemView>& values) override;
78  Status clear() override;
79 
80  template<typename Decorator>
81  Decorator* get() { return m_storage.template get<Decorator>(); }
82 
83  void validate_rw(FormatMatrix format);
84  void validate_rwd(FormatMatrix format);
85  void validate_wd(FormatMatrix format);
86  void validate_ctor(FormatMatrix format);
87  bool is_valid(FormatMatrix format) const;
88  T get_fill_value() const { return m_storage.get_fill_value(); }
89 
91 
92  private:
93  typename StorageManagerMatrix<T>::Storage m_storage;
94  std::string m_label;
95  };
96 
97  template<typename T>
98  TMatrix<T>::TMatrix(uint n_rows, uint n_cols) {
99  m_storage.set_dims(n_rows, n_cols);
100  }
101 
102  template<typename T>
104  return m_storage.get_n_rows();
105  }
106  template<typename T>
108  return m_storage.get_n_cols();
109  }
110  template<typename T>
112  return get_ttype<T>().template as<Type>();
113  }
114 
115  template<typename T>
116  void TMatrix<T>::set_label(std::string label) {
117  m_label = std::move(label);
118  LOG_MSG(Status::Ok, "set label '" << m_label << "' to " << (void*) this);
119  }
120  template<typename T>
121  const std::string& TMatrix<T>::get_label() const {
122  return m_label;
123  }
124 
125  template<typename T>
127  validate_rw(format);
128  return Status::Ok;
129  }
130  template<typename T>
132  if (value) {
133  m_storage.invalidate();
134 
135  if constexpr (std::is_same<T, T_INT>::value) m_storage.set_fill_value(value->as_int());
136  if constexpr (std::is_same<T, T_UINT>::value) m_storage.set_fill_value(value->as_uint());
137  if constexpr (std::is_same<T, T_FLOAT>::value) m_storage.set_fill_value(value->as_float());
138 
139  return Status::Ok;
140  }
141 
143  }
144  template<typename T>
146  auto reduce = resolve_duplicates.template cast_safe<TOpBinary<T, T, T>>();
147 
148  if (reduce) {
149  validate_ctor(FormatMatrix::CpuLil);
150  get<CpuLil<T>>()->reduce = reduce->function;
151  validate_ctor(FormatMatrix::CpuDok);
152  get<CpuDok<T>>()->reduce = reduce->function;
153  }
154 
156  }
157 
158  template<typename T>
159  Status TMatrix<T>::set_int(uint row_id, uint col_id, std::int32_t value) {
160  validate_rwd(FormatMatrix::CpuLil);
161  cpu_lil_add_element(row_id, col_id, static_cast<T>(value), *get<CpuLil<T>>());
162  return Status::Ok;
163  }
164  template<typename T>
165  Status TMatrix<T>::set_uint(uint row_id, uint col_id, std::uint32_t value) {
166  validate_rwd(FormatMatrix::CpuLil);
167  cpu_lil_add_element(row_id, col_id, static_cast<T>(value), *get<CpuLil<T>>());
168  return Status::Ok;
169  }
170  template<typename T>
171  Status TMatrix<T>::set_float(uint row_id, uint col_id, float value) {
172  validate_rwd(FormatMatrix::CpuLil);
173  cpu_lil_add_element(row_id, col_id, static_cast<T>(value), *get<CpuLil<T>>());
174  return Status::Ok;
175  }
176 
177  template<typename T>
178  Status TMatrix<T>::get_int(uint row_id, uint col_id, int32_t& value) {
179  validate_rw(FormatMatrix::CpuDok);
180 
181  auto& Ax = get<CpuDok<T>>()->Ax;
182  auto entry = Ax.find(typename CpuDok<T>::Key(row_id, col_id));
183  value = m_storage.get_fill_value();
184 
185  if (entry != Ax.end()) {
186  value = static_cast<int32_t>(entry->second);
187  }
188 
189  return Status::Ok;
190  }
191  template<typename T>
192  Status TMatrix<T>::get_uint(uint row_id, uint col_id, uint32_t& value) {
193  validate_rw(FormatMatrix::CpuDok);
194 
195  auto& Ax = get<CpuDok<T>>()->Ax;
196  auto entry = Ax.find(typename CpuDok<T>::Key(row_id, col_id));
197  value = m_storage.get_fill_value();
198 
199  if (entry != Ax.end()) {
200  value = static_cast<uint32_t>(entry->second);
201  }
202 
203  return Status::Ok;
204  }
205  template<typename T>
206  Status TMatrix<T>::get_float(uint row_id, uint col_id, float& value) {
207  validate_rw(FormatMatrix::CpuDok);
208 
209  auto& Ax = get<CpuDok<T>>()->Ax;
210  auto entry = Ax.find(typename CpuDok<T>::Key(row_id, col_id));
211  value = m_storage.get_fill_value();
212 
213  if (entry != Ax.end()) {
214  value = static_cast<float>(entry->second);
215  }
216 
217  return Status::Ok;
218  }
219 
220  template<typename T>
221  Status TMatrix<T>::build(const ref_ptr<MemView>& keys1, const ref_ptr<MemView>& keys2, const ref_ptr<MemView>& values) {
222  assert(keys1);
223  assert(keys2);
224  assert(values);
225 
226  const auto key_size = sizeof(uint);
227  const auto value_size = sizeof(T);
228  const auto elements_count = keys1->get_size() / key_size;
229 
230  if (elements_count != values->get_size() / value_size) {
232  }
233  if (elements_count * key_size != keys1->get_size()) {
235  }
236  if (elements_count * key_size != keys2->get_size()) {
238  }
239 
240  validate_rwd(FormatMatrix::CpuCoo);
241  CpuCoo<T>& coo = *get<CpuCoo<T>>();
242 
243  coo.Ai.resize(elements_count);
244  coo.Aj.resize(elements_count);
245  coo.Ax.resize(elements_count);
246  coo.values = uint(elements_count);
247 
248  keys1->read(0, key_size * elements_count, coo.Ai.data());
249  keys2->read(0, key_size * elements_count, coo.Aj.data());
250  values->read(0, value_size * elements_count, coo.Ax.data());
251 
252  return Status::Ok;
253  }
254  template<typename T>
256  const auto key_size = sizeof(uint);
257  const auto value_size = sizeof(T);
258 
259  validate_rw(FormatMatrix::CpuCoo);
260  CpuCoo<T>& coo = *get<CpuCoo<T>>();
261 
262  const auto elements_count = coo.Ai.size();
263 
264  keys1 = MemView::make(coo.Ai.data(), key_size * elements_count, false);
265  keys2 = MemView::make(coo.Aj.data(), key_size * elements_count, false);
266  values = MemView::make(coo.Ax.data(), value_size * elements_count, false);
267 
268  return Status::Ok;
269  }
270 
271  template<typename T>
273  m_storage.invalidate();
274  return Status::Ok;
275  }
276 
277  template<typename T>
279  StorageManagerMatrix<T>* manager = get_storage_manager();
280  manager->validate_rw(format, m_storage);
281  }
282 
283  template<typename T>
285  StorageManagerMatrix<T>* manager = get_storage_manager();
286  manager->validate_rwd(format, m_storage);
287  }
288 
289  template<typename T>
291  StorageManagerMatrix<T>* manager = get_storage_manager();
292  manager->validate_wd(format, m_storage);
293  }
294 
295  template<typename T>
297  StorageManagerMatrix<T>* manager = get_storage_manager();
298  manager->validate_ctor(format, m_storage);
299  }
300 
301  template<typename T>
302  bool TMatrix<T>::is_valid(FormatMatrix format) const {
303  return m_storage.is_valid(format);
304  }
305 
306  template<typename T>
308  static std::unique_ptr<StorageManagerMatrix<T>> storage_manager;
309 
310  if (!storage_manager) {
311  storage_manager = std::make_unique<StorageManagerMatrix<T>>();
312  register_formats_matrix(*storage_manager);
313  }
314 
315  return storage_manager.get();
316  }
317 
322 }// namespace spla
323 
324 
325 #endif//SPLA_TMATRIX_HPP
Named storage format for library matrix data objects.
Status of library operation execution.
CPU list of coordinates matrix format.
Definition: cpu_formats.hpp:148
std::vector< uint > Aj
Definition: cpu_formats.hpp:155
std::vector< T > Ax
Definition: cpu_formats.hpp:156
std::vector< uint > Ai
Definition: cpu_formats.hpp:154
std::pair< uint, uint > Key
Definition: cpu_formats.hpp:134
CPU list-of-list matrix format for fast incremental build.
Definition: cpu_formats.hpp:107
Generalized M x N dimensional matrix object.
Definition: matrix.hpp:48
static ref_ptr< MemView > make()
Definition: memview.cpp:67
General format converter for vector or matrix decoration storage.
Definition: storage_manager.hpp:57
Storage for decorators with data of a particular vector or matrix object.
Definition: tdecoration.hpp:70
uint values
Definition: tdecoration.hpp:58
Matrix interface implementation with type information bound.
Definition: tmatrix.hpp:57
~TMatrix() override=default
Decorator * get()
Definition: tmatrix.hpp:81
T get_fill_value() const
Definition: tmatrix.hpp:88
Automates reference counting and behaves as shared smart pointer.
Definition: ref.hpp:117
ref_ptr< Type > get_type() override
Definition: tmatrix.hpp:111
TMatrix(uint n_rows, uint n_cols)
Definition: tmatrix.hpp:98
void validate_wd(FormatMatrix format)
Definition: tmatrix.hpp:290
void validate_rwd(FormatMatrix format)
Definition: tmatrix.hpp:284
bool is_valid(FormatMatrix format) const
Definition: tmatrix.hpp:302
void validate_wd(F format, Storage &storage)
Definition: storage_manager.hpp:212
static StorageManagerMatrix< T > * get_storage_manager()
Definition: tmatrix.hpp:307
void validate_rw(F format, Storage &storage)
Definition: storage_manager.hpp:128
Status get_int(uint row_id, uint col_id, int32_t &value) override
Definition: tmatrix.hpp:178
const std::string & get_label() const override
Definition: tmatrix.hpp:121
void validate_rw(FormatMatrix format)
Definition: tmatrix.hpp:278
Status set_fill_value(const ref_ptr< Scalar > &value) override
Definition: tmatrix.hpp:131
uint get_n_cols() override
Definition: tmatrix.hpp:107
Status get_uint(uint row_id, uint col_id, uint32_t &value) override
Definition: tmatrix.hpp:192
Status clear() override
Definition: tmatrix.hpp:272
Status set_int(uint row_id, uint col_id, std::int32_t value) override
Definition: tmatrix.hpp:159
Status set_uint(uint row_id, uint col_id, std::uint32_t value) override
Definition: tmatrix.hpp:165
Status get_float(uint row_id, uint col_id, float &value) override
Definition: tmatrix.hpp:206
Status read(ref_ptr< MemView > &keys1, ref_ptr< MemView > &keys2, ref_ptr< MemView > &values) override
Definition: tmatrix.hpp:255
void set_label(std::string label) override
Definition: tmatrix.hpp:116
void cpu_lil_add_element(uint row_id, uint col_id, T element, CpuLil< T > &lil)
Definition: cpu_format_lil.hpp:55
void validate_ctor(F format, Storage &storage)
Definition: storage_manager.hpp:121
void validate_rwd(F format, Storage &storage)
Definition: storage_manager.hpp:206
Status set_float(uint row_id, uint col_id, float value) override
Definition: tmatrix.hpp:171
Status set_format(FormatMatrix format) override
Definition: tmatrix.hpp:126
uint get_n_rows() override
Definition: tmatrix.hpp:103
void validate_ctor(FormatMatrix format)
Definition: tmatrix.hpp:296
Status build(const ref_ptr< MemView > &keys1, const ref_ptr< MemView > &keys2, const ref_ptr< MemView > &values) override
Definition: tmatrix.hpp:221
Status set_reduce(ref_ptr< OpBinary > resolve_duplicates) override
Definition: tmatrix.hpp:145
std::uint32_t uint
Library index and size type.
Definition: config.hpp:56
#define LOG_MSG(status, msg)
Definition: logger.hpp:66
Definition: algorithm.hpp:37
void register_formats_matrix(StorageManagerMatrix< T > &manager)
Definition: storage_manager_matrix.hpp:51