I have a large csr_matrix and I am interested in the top ten values and their indices in each row. But I have not found a decent way to manipulate the matrix.
Just to answer the original question (for people like me who found this question looking for copy-pasta), here's a solution using multiprocessing, based on @hpaulj's suggestion of converting to lil_matrix and iterating over the rows:
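from functools import partial
from multiprocessing import Pool


def _top_k(args, k):
    """
    Helper function to process a single row of top_k
    """
    data, row = args
    if not data:
        # Empty row: nothing to keep
        return [], []
    # Keep the k largest (value, column) pairs, then restore column order
    top = sorted(sorted(zip(data, row), reverse=True)[:k], key=lambda t: t[1])
    data, row = zip(*top)
    return list(data), list(row)


def top_k(m, k):
    """
    Keep only the top k elements of each row in a csr_matrix
    """
    ml = m.tolil()
    with Pool() as p:
        # k is bound with partial because the workers cannot see top_k's local k
        ms = p.map(partial(_top_k, k=k), zip(ml.data, ml.rows))
    # lil_matrix stores one Python list per row, so write the results back row by row
    for i, (data, row) in enumerate(ms):
        ml.data[i], ml.rows[i] = data, row
    return ml.tocsr()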
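
Since the question asks for the top values and their indices rather than a pruned matrix, here is a rough sketch of another way to get them: walk the CSR arrays directly and skip both the lil_matrix conversion and the process pool. This is not the approach from the answer above, just an alternative under a few assumptions: the function name top_k_per_row is made up for illustration, k is at least 1, and only explicitly stored entries are considered (so a row with fewer than k stored entries returns fewer than k results, and implicit zeros never show up).

```python
import numpy as np
from scipy.sparse import csr_matrix


def top_k_per_row(m, k):
    """
    Return one (values, column_indices) pair per row of a csr_matrix,
    with values sorted from largest to smallest. Assumes k >= 1.
    """
    result = []
    # indptr[i]:indptr[i + 1] is the slice of data/indices belonging to row i
    for start, end in zip(m.indptr[:-1], m.indptr[1:]):
        data = m.data[start:end]
        cols = m.indices[start:end]
        if len(data) > k:
            # argpartition picks the k largest entries without a full sort
            keep = np.argpartition(data, -k)[-k:]
            data, cols = data[keep], cols[keep]
        # Sort the surviving entries from largest to smallest
        order = np.argsort(data)[::-1]
        result.append((data[order], cols[order]))
    return result


if __name__ == "__main__":
    m = csr_matrix(np.array([[0.0, 3.0, 1.0, 2.0],
                             [0.0, 0.0, 0.0, 0.0],
                             [5.0, 0.0, 4.0, 0.0]]))
    for values, cols in top_k_per_row(m, 2):
        print(values, cols)
```

The loop is still per-row Python, so for very large matrices it may or may not beat the multiprocessing version; it is worth timing both on your own data.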