Scipy.sparse.csr_matrix: How to get top ten values and indices?

刺人心 2020-12-06 05:40

I have a large csr_matrix and I am interested in the top ten values and their indices in each row, but I haven't found a decent way to manipulate the matrix.

3 answers
  •  误落风尘
    2020-12-06 06:15

    One has to iterate over the rows and find the top indices for each row separately, but this loop can be JIT-compiled (and parallelized) with Numba to get an extremely fast function.
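    To make the CSR layout concrete: the stored values of row `i` are `data[indptr[i]:indptr[i+1]]`, and their column numbers are the same slice of `indices`. Below is a minimal, unjitted sketch of that per-row loop in plain Python, using `heapq.nlargest` for the selection. The helper name `row_topk_reference` and the toy arrays are hypothetical, for illustration only.

```python
import heapq

def row_topk_reference(data, indices, indptr, k):
    """Hypothetical helper: top-k (column, value) pairs per row of a CSR matrix.

    Only explicitly stored entries are considered; implicit zeros are ignored.
    """
    result = []
    for i in range(len(indptr) - 1):
        start, end = indptr[i], indptr[i + 1]
        # Pair each stored value in row i with its column index.
        row = zip(indices[start:end], data[start:end])
        # Keep the k pairs with the largest values, largest first.
        result.append(heapq.nlargest(k, row, key=lambda cv: cv[1]))
    return result

# Toy 3 x 4 matrix in CSR form (hypothetical data):
#   [[0, 5, 0, 1],
#    [0, 0, 0, 0],
#    [7, 2, 9, 0]]
data = [5, 1, 7, 2, 9]
indices = [1, 3, 0, 1, 2]
indptr = [0, 2, 2, 5]

print(row_topk_reference(data, indices, indptr, 2))
# [[(1, 5), (3, 1)], [], [(2, 9), (0, 7)]]
```

    Note that the empty middle row simply yields an empty list; a fixed-size output array, as in the Numba version, needs some padding convention for such rows instead.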

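    import numba as nb
    import numpy as np

    @nb.njit(cache=True, parallel=True)  # parallel=True lets nb.prange run rows in parallel
    def row_topk_csr(data, indices, indptr, K):
        m = indptr.shape[0] - 1
        # Column indices of each row's top-K entries; -1 marks unused slots
        # when a row stores fewer than K values.
        max_indices = np.zeros((m, K), dtype=indices.dtype)
        max_indices[:] = -1
        max_values = np.zeros((m, K), dtype=data.dtype)

        for i in nb.prange(m):
            # Stored values and column numbers of row i (implicit zeros are not considered)
            row_data = data[indptr[i] : indptr[i + 1]]
            row_cols = indices[indptr[i] : indptr[i + 1]]
            # Positions of the K largest stored values, largest first
            top_inds = np.argsort(row_data)[::-1][:K]
            n = top_inds.shape[0]  # may be smaller than K for sparse rows
            max_indices[i, :n] = row_cols[top_inds]
            max_values[i, :n] = row_data[top_inds]

        return max_indices, max_values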
    

    Call it like this:

    top_pred_indices, _ = row_topk_csr(csr_mat.data, csr_mat.indices, csr_mat.indptr, K)
    

    I need to perform this operation frequently, and this function is fast enough for me: it runs in under 1 s on a 1M x 400k sparse matrix.
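    The jitted function fully sorts each row with `np.argsort`. When rows store many more values than `K`, a partial sort with `np.argpartition` is a common alternative: it selects the `K` largest positions in linear time and then sorts only those. Below is an unjitted NumPy sketch of that swapped-in technique; it takes the same `csr_mat.data`, `csr_mat.indices`, `csr_mat.indptr` arrays. The name `row_topk_partition` is hypothetical, and whether it is faster in practice depends on your row lengths.

```python
import numpy as np

def row_topk_partition(data, indices, indptr, K):
    """Hypothetical variant: top-K per CSR row via np.argpartition.

    Rows with fewer than K stored values are padded with column -1 and value 0.
    """
    m = indptr.shape[0] - 1
    top_cols = np.full((m, K), -1, dtype=indices.dtype)
    top_vals = np.zeros((m, K), dtype=data.dtype)
    for i in range(m):
        row_data = data[indptr[i]:indptr[i + 1]]
        row_cols = indices[indptr[i]:indptr[i + 1]]
        n = row_data.shape[0]
        if n > K:
            # Unordered positions of the K largest values (linear time).
            part = np.argpartition(row_data, n - K)[n - K:]
        else:
            part = np.arange(n)
        # Sort only the selected positions, largest value first.
        order = part[np.argsort(row_data[part])[::-1]]
        top_cols[i, :order.shape[0]] = row_cols[order]
        top_vals[i, :order.shape[0]] = row_data[order]
    return top_cols, top_vals

# Toy CSR arrays for [[0, 5, 0, 1], [0, 0, 0, 0], [7, 2, 9, 0]] (hypothetical data)
data = np.array([5, 1, 7, 2, 9])
indices = np.array([1, 3, 0, 1, 2], dtype=np.int32)
indptr = np.array([0, 2, 2, 5], dtype=np.int32)

cols, vals = row_topk_partition(data, indices, indptr, 2)
print(cols.tolist())  # [[1, 3], [-1, -1], [2, 0]]
print(vals.tolist())  # [[5, 1], [0, 0], [9, 7]]
```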

    HTH.
