Scipy.sparse.csr_matrix: How to get top ten values and indices?

刺人心 2020-12-06 05:40

I have a large csr_matrix and I am interested in the top ten values and their indices in each row, but I have not found a decent way to manipulate the matrix.

3 Answers
  •  孤街浪徒
    2020-12-06 05:59

    Just to answer the original question (for people like me who found this question looking for copy-pasta), here's a solution using multiprocessing, based on @hpaulj's suggestion of converting to lil_matrix and iterating over the rows:

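    from multiprocessing import Pool
    
    def _top_k(args):
        """
        Helper function to process a single row of top_k
        """
        data, cols, k = args  # k is passed in explicitly; workers cannot see top_k's locals
        # Keep the k largest (value, column) pairs, then restore column order,
        # since LIL rows are expected to be sorted by column index
        top = sorted(sorted(zip(data, cols), reverse=True)[:k], key=lambda t: t[1])
        return [v for v, _ in top], [c for _, c in top]
    
    def top_k(m, k):
        """
        Keep only the top k elements of each row in a csr_matrix
        """
        ml = m.tolil()
        with Pool() as p:
            ms = p.map(_top_k, [(d, r, k) for d, r in zip(ml.data, ml.rows)])
        # Assign row by row so ml.data and ml.rows stay object arrays of lists
        # (this also handles empty rows, which zip(*...) would choke on)
        for i, (data, cols) in enumerate(ms):
            ml.data[i], ml.rows[i] = data, cols
        return ml.tocsr()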
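
If you want the values and column indices themselves rather than a pruned matrix, one alternative is to walk the CSR arrays (indptr, indices, data) directly, skipping the lil_matrix conversion and the process pool. This is only a minimal sketch: top_k_per_row is a hypothetical helper name, not a SciPy function, and it only looks at explicitly stored entries, so implicit zeros are never returned even when a row has fewer than k stored values.

```python
import numpy as np
from scipy.sparse import csr_matrix


def top_k_per_row(m, k):
    """Return the k largest stored values of each row and their column indices.

    Hypothetical helper for illustration; assumes m is a csr_matrix.
    """
    values, columns = [], []
    # indptr[i]:indptr[i + 1] delimits the stored entries of row i
    for start, end in zip(m.indptr[:-1], m.indptr[1:]):
        row_data = m.data[start:end]
        row_cols = m.indices[start:end]
        # Order the stored entries by value, largest first, and keep at most k
        order = np.argsort(row_data)[::-1][:k]
        values.append(row_data[order])
        columns.append(row_cols[order])
    return values, columns


m = csr_matrix(np.array([[0, 5, 3, 0],
                         [7, 0, 0, 1],
                         [0, 0, 0, 0]]))
values, columns = top_k_per_row(m, 2)
```

Here values[i] and columns[i] are NumPy arrays for row i, and an empty row simply yields two empty arrays. For very long rows, np.argpartition could replace the full argsort, at the cost of an extra sort over the k selected entries.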
    
