28 #ifndef SPLA_CPU_M_TRANSPOSE_HPP
29 #define SPLA_CPU_M_TRANSPOSE_HPP
// Body of get_description() (the listing's cross-reference places its definition at line 55;
// the signature itself is not part of this excerpt): returns a fixed, human-readable description string.
56 return "transpose matrix on cpu sequentially";
// Fragment of execute(const DispatchContext& ctx). The signature and the conditions that
// guard each return below are omitted from this listing.
// Cast the context's task to the transpose task type and fetch its input matrix M.
60 auto t = ctx.
task.template cast_safe<ScheduleTask_m_transpose>();
61 auto M = t->M.template cast_safe<TMatrix<T>>();
// Hand off to one of the format-specific implementations (CSR, LIL, DOK).
// NOTE(review): the selection conditions are not visible here; presumably they test which
// storage format of M is currently valid - confirm against the full source.
64 return execute_csr(ctx);
67 return execute_lil(ctx);
70 return execute_dok(ctx);
// Last return in the function: the CSR path appears to double as the default - confirm.
73 return execute_csr(ctx);
// Fragment of the DOK transpose path (presumably execute_dok, judging by the call in execute()
// and the CpuDok storages used below). The signature and the declarations of R, M and op_apply
// are omitted from this listing.
80 auto t = ctx.
task.template cast_safe<ScheduleTask_m_transpose>();
// Mutable DOK storage of the result R and read-only DOK storage of the input M.
89 CpuDok<T>* p_dok_R = R->template get<CpuDok<T>>();
90 const CpuDok<T>* p_dok_M = M->template get<CpuDok<T>>();
// Callable applied to every stored value while it is moved into the result.
91 auto& func_apply = op_apply->function;
// Precondition (debug builds only): the result map must not contain any entries yet.
93 assert(p_dok_R->
Ax.empty());
// Reserve room for as many entries as the input holds before inserting.
95 p_dok_R->
Ax.reserve(p_dok_M->Ax.size());
// For each stored entry, swap the two components of its pair key and store the
// transformed value under the swapped key.
// NOTE(review): presumably the key is (row, column) - confirm against the CpuDok definition.
97 for (
const auto& entry : p_dok_M->Ax) {
98 p_dok_R->
Ax[{entry.first.second, entry.first.first}] = func_apply(entry.second);
// Copy the `values` field from the input storage to the result storage.
101 p_dok_R->
values = p_dok_M->values;
// Transposes matrix M into R using list-of-lists (CpuLil) storage, applying op_apply's function
// to every stored value. The listing omits some lines of this function (including the closing
// braces and the final return statement).
106 Status execute_lil(
const DispatchContext& ctx) {
// Cast the context's task to the transpose task type.
109 auto t = ctx.task.template cast_safe<ScheduleTask_m_transpose>();
// Typed handles: R is the result matrix, M the input matrix, op_apply the unary T -> T operation.
111 ref_ptr<TMatrix<T>> R = t->R.template cast_safe<TMatrix<T>>();
112 ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
113 ref_ptr<TOpUnary<T, T>> op_apply = t->op_apply.template cast_safe<TOpUnary<T, T>>();
// Mutable LIL storage of R and read-only LIL storage of M.
118 CpuLil<T>* p_lil_R = R->template get<CpuLil<T>>();
119 const CpuLil<T>* p_lil_M = M->template get<CpuLil<T>>();
120 auto& func_apply = op_apply->function;
// Dimensions of the input: DM rows, DN columns.
122 const uint DM = M->get_n_rows();
123 const uint DN = M->get_n_cols();
// Debug-only shape checks: R must have the transposed dimensions of M.
125 assert(M->get_n_rows() == R->get_n_cols());
126 assert(M->get_n_cols() == R->get_n_rows());
// Debug-only check: R's row-list container must already hold one list per column of M.
128 assert(p_lil_R->Ar.size() == DN);
// Visit the rows of M in increasing order of i; each (j, x) element of row i is appended
// to row j of R as (i, func_apply(x)).
130 for (
uint i = 0; i < DM; i++) {
131 for (
const auto [j, x] : p_lil_M->Ar[i]) {
132 p_lil_R->Ar[j].emplace_back(i, func_apply(x));
// Copy the `values` field from the input storage to the result storage.
136 p_lil_R->values = p_lil_M->values;
// Transposes matrix M into R using CSR (CpuCsr) storage, applying op_apply's function to every
// stored value. The listing omits several lines of this function (statements inside both loop
// nests, closing braces and the final return), so the notes below cover only what is visible.
141 Status execute_csr(
const DispatchContext& ctx) {
// Cast the context's task to the transpose task type.
144 auto t = ctx.task.template cast_safe<ScheduleTask_m_transpose>();
// Typed handles: R is the result matrix, M the input matrix, op_apply the unary T -> T operation.
146 ref_ptr<TMatrix<T>> R = t->R.template cast_safe<TMatrix<T>>();
147 ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
148 ref_ptr<TOpUnary<T, T>> op_apply = t->op_apply.template cast_safe<TOpUnary<T, T>>();
// Mutable CSR storage of R and read-only CSR storage of M.
153 CpuCsr<T>* p_csr_R = R->template get<CpuCsr<T>>();
154 const CpuCsr<T>* p_csr_M = M->template get<CpuCsr<T>>();
155 auto& func_apply = op_apply->function;
// Dimensions of the input: DM rows, DN columns.
157 const uint DM = M->get_n_rows();
158 const uint DN = M->get_n_cols();
// Debug-only shape checks: R must have the transposed dimensions of M.
160 assert(M->get_n_rows() == R->get_n_cols());
161 assert(M->get_n_cols() == R->get_n_rows());
// DN + 1 zero-initialised counters.
163 std::vector<uint> sizes(DN + 1, 0);
// First pass over every stored entry of M: k walks row i's range [Ap[i], Ap[i + 1]) and
// j is the column index of the entry.
// NOTE(review): the statement that uses j is missing from this listing; presumably it counts
// the entries per column into `sizes` - confirm against the full source.
165 for (
uint i = 0; i < DM; i++) {
166 for (
uint k = p_csr_M->Ap[i]; k < p_csr_M->Ap[i + 1]; k++) {
167 uint j = p_csr_M->Aj[k];
// Exclusive prefix sum of `sizes` (initial value 0) written into R's Ap array.
// NOTE(review): R's arrays must already be large enough at this point; the sizing call is not
// visible here (presumably cpu_csr_resize on an omitted line) - confirm.
173 std::exclusive_scan(sizes.begin(), sizes.end(), p_csr_R->Ap.begin(), 0);
// One zero-initialised write offset per row of R.
175 std::vector<uint> offsets(DN, 0);
// Second pass over every stored entry (i, j, x) of M: write column index i and value
// func_apply(x) into row j of R at position Ap[j] + offsets[j].
// NOTE(review): no update of offsets[j] is visible in this listing; presumably it is advanced
// on an omitted line after each write - confirm.
177 for (
uint i = 0; i < DM; i++) {
178 for (
uint k = p_csr_M->Ap[i]; k < p_csr_M->Ap[i + 1]; k++) {
179 uint j = p_csr_M->Aj[k];
180 T x = p_csr_M->Ax[k];
182 p_csr_R->Aj[p_csr_R->Ap[j] + offsets[j]] = i;
183 p_csr_R->Ax[p_csr_R->Ap[j] + offsets[j]] = func_apply(x);
// Copy the `values` field from the input storage to the result storage.
189 p_csr_R->values = p_csr_M->values;
Status of library operation execution.
Definition: cpu_m_transpose.hpp:47
std::string get_name() override
Definition: cpu_m_transpose.hpp:51
~Algo_m_transpose_cpu() override=default
Status execute(const DispatchContext &ctx) override
Definition: cpu_m_transpose.hpp:59
std::string get_description() override
Definition: cpu_m_transpose.hpp:55
Dictionary-of-keys (DOK) sparse matrix format.
Definition: cpu_formats.hpp:128
robin_hood::unordered_flat_map< Key, T, pair_hash > Ax
Definition: cpu_formats.hpp:137
Algorithm suitable to process schedule task based on task string key.
Definition: registry.hpp:66
uint values
Definition: tdecoration.hpp:58
Automates reference counting and behaves as a shared smart pointer.
Definition: ref.hpp:117
void cpu_csr_resize(const uint n_rows, const uint n_values, CpuCsr< T > &storage)
Definition: cpu_format_csr.hpp:41
std::uint32_t uint
Library index and size type.
Definition: config.hpp:56
Definition: algorithm.hpp:37
Execution context of a single task.
Definition: dispatcher.hpp:46
ref_ptr< ScheduleTask > task
Definition: dispatcher.hpp:48
#define TIME_PROFILE_SCOPE(name)
Definition: time_profiler.hpp:92