I have a large csr_matrix and I am interested in the top ten values and their indices in each row. But I did not find a decent way to manipulate the matrix.
You have to iterate over the rows and get the top indices for each row separately. But this loop can be JIT-compiled (and parallelized) with Numba to get an extremely fast function.
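    import numba as nb
    import numpy as np

    @nb.njit(cache=True, parallel=True)  # parallel=True is needed for prange to run in parallel
    def row_topk_csr(data, indices, indptr, K):
        m = indptr.shape[0] - 1
        # Rows with fewer than K stored entries stay zero-padded on the right.
        max_indices = np.zeros((m, K), dtype=indices.dtype)
        max_values = np.zeros((m, K), dtype=data.dtype)
        for i in nb.prange(m):  # rows are independent of each other
            start, end = indptr[i], indptr[i + 1]
            # Positions of the K largest stored values of row i, in descending order
            top_inds = np.argsort(data[start:end])[::-1][:K]
            n = top_inds.shape[0]
            max_indices[i, :n] = indices[start:end][top_inds]
            max_values[i, :n] = data[start:end][top_inds]
        return max_indices, max_values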
Call it like this:
    top_pred_indices, _ = row_topk_csr(csr_mat.data, csr_mat.indices, csr_mat.indptr, K)
I need to perform this operation frequently, and this function is fast enough for me: it runs in under 1 s on a 1,000,000 x 400,000 sparse matrix.
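If you want to sanity-check the output on a small input, a plain-NumPy reference can help. The sketch below is only an illustration: `row_topk_csr_reference` is a hypothetical name, the CSR arrays are written out by hand for a tiny 3 x 4 matrix, and it assumes that rows with fewer than K stored entries should be zero-padded.

```python
import numpy as np

def row_topk_csr_reference(data, indices, indptr, K):
    """Plain-NumPy reference: top-K stored values per CSR row, zero-padded."""
    m = len(indptr) - 1
    top_cols = np.zeros((m, K), dtype=indices.dtype)
    top_vals = np.zeros((m, K), dtype=data.dtype)
    for i in range(m):
        start, end = indptr[i], indptr[i + 1]
        # Positions (within the row slice) of the K largest stored values
        order = np.argsort(data[start:end])[::-1][:K]
        n = len(order)  # may be smaller than K for short or empty rows
        top_cols[i, :n] = indices[start:end][order]
        top_vals[i, :n] = data[start:end][order]
    return top_cols, top_vals

# Hand-written CSR arrays for the 3 x 4 matrix
# [[0, 5, 0, 2],
#  [0, 0, 0, 0],
#  [7, 1, 3, 0]]
data = np.array([5.0, 2.0, 7.0, 1.0, 3.0])
indices = np.array([1, 3, 0, 1, 2])
indptr = np.array([0, 2, 2, 5])

cols, vals = row_topk_csr_reference(data, indices, indptr, K=2)
print(cols)  # column indices of the top-2 stored values in each row
print(vals)  # the corresponding values
```

Two things to keep in mind: only explicitly stored entries are ranked, so implicit zeros never appear in the result even if a row's stored values are negative, and a padded column index of 0 cannot be told apart from a real hit in column 0 without also looking at the row length `indptr[i + 1] - indptr[i]`.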
HTH.