53 return "m_extract_column";
57 return "extract matrix column on cpu sequentially";
61 auto t = ctx.
task.template cast_safe<ScheduleTask_m_extract_column>();
62 auto M = t->M.template cast_safe<TMatrix<T>>();
65 return execute_csr(ctx);
68 return execute_lil(ctx);
71 return execute_dok(ctx);
74 return execute_csr(ctx);
81 auto t = ctx.
task.template cast_safe<ScheduleTask_m_extract_column>();
86 uint index = t->index;
92 const CpuDok<T>* p_dok_M = M->template get<CpuDok<T>>();
93 auto& func_apply = op_apply->function;
95 for (
const auto [key, value] : p_dok_M->Ax) {
96 if (key.second == index) {
98 p_coo_r->
Ai.push_back(key.first);
99 p_coo_r->
Ax.push_back(func_apply(value));
// Column extraction for a matrix held in CpuLil storage: scans every row of M
// for an entry whose `first` equals `index` and appends the match to the
// CpuCooVec storage of the result vector r.
// NOTE(review): this listing omits some original lines (the embedded numbering
// skips), including the definition of `fake`, the closing braces and the
// returned Status; the comments below describe only the visible statements.
108 Status execute_lil(
const DispatchContext& ctx) {
// Recover the concrete task object carried by the dispatch context.
111 auto t = ctx.task.template cast_safe<ScheduleTask_m_extract_column>();
// Typed views of the task operands: result vector, source matrix, unary op.
113 ref_ptr<TVector<T>> r = t->r.template cast_safe<TVector<T>>();
114 ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
115 ref_ptr<TOpUnary<T, T>> op_apply = t->op_apply.template cast_safe<TOpUnary<T, T>>();
// Value that each row entry's `first` is matched against (presumably the
// column to extract, given the task name - confirm against the task type).
116 uint index = t->index;
// Format-specific storage: r is written through its CpuCooVec, M is read
// through its CpuLil.
121 CpuCooVec<T>* p_coo_r = r->template get<CpuCooVec<T>>();
122 const CpuLil<T>* p_lil_M = M->template get<CpuLil<T>>();
// Callable applied to every extracted value before it is stored.
123 auto& func_apply = op_apply->function;
// Visit rows in increasing order, so row indices are appended to Ai in
// ascending order.
125 for (
uint i = 0; i < M->get_n_rows(); i++) {
126 const auto& row = p_lil_M->Ar[i];
// Binary search by `first`; std::lower_bound assumes the row is ordered by
// `first` - TODO confirm CpuLil guarantees this. `fake` is defined on a line
// not shown here; presumably a probe entry whose `first` is `index`.
129 auto query = std::lower_bound(row.begin(), row.end(), fake, [](
auto& a,
auto& b) { return a.first < b.first; });
// lower_bound yields the first entry not less than the probe, so an exact
// match still has to be checked explicitly.
131 if (query != row.end() && query->first == index) {
// Record the hit: bump the stored-entry counter, then append the row index
// and the transformed value.
132 p_coo_r->values += 1;
133 p_coo_r->Ai.push_back(i);
134 p_coo_r->Ax.push_back(func_apply(query->second));
// Column extraction for a matrix held in CpuCsr storage: for every row, looks
// up `index` among that row's slice of Aj and appends the matching entry to
// the CpuCooVec storage of the result vector r.
// NOTE(review): this listing omits some original lines (the embedded numbering
// skips), including the closing braces and the returned Status; the comments
// below describe only the visible statements.
141 Status execute_csr(
const DispatchContext& ctx) {
// Recover the concrete task object carried by the dispatch context.
144 auto t = ctx.task.template cast_safe<ScheduleTask_m_extract_column>();
// Typed views of the task operands: result vector, source matrix, unary op.
146 ref_ptr<TVector<T>> r = t->r.template cast_safe<TVector<T>>();
147 ref_ptr<TMatrix<T>> M = t->M.template cast_safe<TMatrix<T>>();
148 ref_ptr<TOpUnary<T, T>> op_apply = t->op_apply.template cast_safe<TOpUnary<T, T>>();
// Value searched for in Aj (presumably the column to extract, given the task
// name - confirm against the task type).
149 uint index = t->index;
// Format-specific storage: r is written through its CpuCooVec, M is read
// through its CpuCsr.
154 CpuCooVec<T>* p_coo_r = r->template get<CpuCooVec<T>>();
155 const CpuCsr<T>* p_csr_M = M->template get<CpuCsr<T>>();
// Callable applied to every extracted value before it is stored.
156 auto& func_apply = op_apply->function;
// Visit rows in increasing order, so row indices are appended to Ai in
// ascending order.
158 for (
uint i = 0; i < M->get_n_rows(); i++) {
// Row i occupies the half-open range [Ap[i], Ap[i + 1]) of Aj.
159 const auto row_begin = p_csr_M->Aj.begin() + p_csr_M->Ap[i];
160 const auto row_end = p_csr_M->Aj.begin() + p_csr_M->Ap[i + 1];
// Binary search inside the row; std::lower_bound assumes the indices of a
// row are stored in ascending order - TODO confirm CpuCsr guarantees this.
162 auto query = std::lower_bound(row_begin, row_end, index);
// lower_bound yields the first element not less than `index`, so an exact
// match still has to be checked explicitly.
164 if (query != row_end && *query == index) {
// Record the hit: bump the stored-entry counter, then append the row index
// and the transformed value. The value is read from Ax at the same offset
// the match has within Aj.
165 p_coo_r->values += 1;
166 p_coo_r->Ai.push_back(i);
167 p_coo_r->Ax.push_back(func_apply(p_csr_M->Ax[std::distance(p_csr_M->Aj.begin(), query)]));
Execution context of a single task.
Definition dispatcher.hpp:46
ref_ptr< ScheduleTask > task
Definition dispatcher.hpp:48