61 auto t = ctx.
task.template cast_safe<ScheduleTask_vxm_masked>();
63 auto r = t->r.template cast_safe<TVector<T>>();
64 auto mask = t->mask.template cast_safe<TVector<T>>();
65 auto v = t->v.template cast_safe<TVector<T>>();
66 auto M = t->M.template cast_safe<TMatrix<T>>();
67 auto op_multiply = t->op_multiply.template cast_safe<TOpBinary<T, T, T>>();
68 auto op_add = t->op_add.template cast_safe<TOpBinary<T, T, T>>();
69 auto op_select = t->op_select.template cast_safe<TOpSelect<T>>();
70 auto init = t->init.template cast_safe<TScalar<T>>();
72 const T sum_init = init->get_value();
79 CpuCooVec<T>* p_sparse_r = r->template get<CpuCooVec<T>>();
80 const CpuDenseVec<T>* p_dense_mask = mask->template get<CpuDenseVec<T>>();
81 const CpuCooVec<T>* p_sparse_v = v->template get<CpuCooVec<T>>();
82 const CpuLil<T>* p_lil_M = M->template get<CpuLil<T>>();
84 auto& func_multiply = op_multiply->function;
85 auto& func_add = op_add->function;
86 auto& func_select = op_select->function;
90 robin_hood::unordered_flat_map<uint, T> r_tmp;
92 for (
uint idx = 0; idx < N; ++idx) {
93 const uint v_i = p_sparse_v->Ai[idx];
94 const T v_x = p_sparse_v->Ax[idx];
96 const auto& row = p_lil_M->Ar[v_i];
98 for (
const auto& j_x : row) {
99 const uint j = j_x.first;
101 if (func_select(p_dense_mask->Ax[j])) {
102 auto r_x = r_tmp.find(j);
104 if (r_x != r_tmp.end())
105 r_x->second = func_add(r_x->second, func_multiply(v_x, j_x.second));
107 r_tmp[j] = func_multiply(v_x, j_x.second);
112 std::vector<std::pair<uint, T>> r_entries;
113 r_entries.reserve(r_tmp.size());
114 for (
const auto& e : r_tmp) {
115 r_entries.emplace_back(e.first, e.second);
117 std::sort(r_entries.begin(), r_entries.end());
120 p_sparse_r->
Ai.reserve(r_tmp.size());
121 p_sparse_r->
Ax.reserve(r_tmp.size());
122 for (
const auto& e : r_entries) {
123 p_sparse_r->
Ai.push_back(e.first);
124 p_sparse_r->
Ax.push_back(e.second);