Fastest way to convert a list of indices to 2D numpy array of ones

星月不相逢 2021-01-05 15:11

I have a list of indices

a = [
  [1,2,4],
  [0,2,3],
  [1,3,4],
  [0,2]]

What's the fastest way to convert this to a 2D numpy array of ones, where each sublist gives the column positions of the ones in the corresponding row?

6 answers
  •  一个人的身影
    2021-01-05 15:30

    If you can and want to use Cython, you can write a solution that is both fast and readable (at least if you don't mind the type declarations).

    Here I'm using the IPython bindings of Cython to compile it in a Jupyter notebook:

    %load_ext cython
    
    %%cython
    
    cimport cython
    cimport numpy as cnp
    import numpy as np
    
    @cython.boundscheck(False)  # remove this if you cannot guarantee that nrow/ncol are correct
    @cython.wraparound(False)
    cpdef cnp.int_t[:, :] mseifert(list a, int nrow, int ncol):
        cdef cnp.int_t[:, :] out = np.zeros([nrow, ncol], dtype=int)
        cdef list subl
        cdef int row_idx
        cdef int col_idx
        for row_idx, subl in enumerate(a):
            for col_idx in subl:
                out[row_idx, col_idx] = 1
        return out
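
    For readers without a Cython toolchain, the same nested loop can be sketched in plain Python and NumPy. This is only an illustration of the logic, not a fast solution: the name indices_to_ones is hypothetical, and ncol=5 is an assumption taken from the largest index (4) in the question's example.

```python
import numpy as np

def indices_to_ones(a, nrow, ncol):
    """Plain-Python version of the loop above (hypothetical helper name)."""
    out = np.zeros((nrow, ncol), dtype=int)
    for row_idx, subl in enumerate(a):
        for col_idx in subl:
            # Mark the column named by each index in this row.
            out[row_idx, col_idx] = 1
    return out

# Example input from the question; ncol=5 is assumed (largest index + 1).
a = [[1, 2, 4], [0, 2, 3], [1, 3, 4], [0, 2]]
result = indices_to_ones(a, nrow=len(a), ncol=5)
print(result)
# [[0 1 1 0 1]
#  [1 0 1 1 0]
#  [0 1 0 1 1]
#  [1 0 1 0 0]]
```

    The Cython version runs this same loop, but with typed indices and a typed memoryview so the inner assignment compiles down to C.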
    

    To compare the performance of the solutions presented here, I used my library simple_benchmark.

    Note that the resulting plot uses logarithmic axes to show the differences for both small and large arrays at once. According to my benchmark, my function is the fastest of the solutions, though none of the solutions is far behind.

    Here is the complete code I used for the benchmark:

    import numpy as np
    from simple_benchmark import BenchmarkBuilder, MultiArgument
    import itertools
    
    b = BenchmarkBuilder()
    
    @b.add_function()
    def pp(a, nrow, ncol):
        sz = np.fromiter(map(len, a), int, nrow)
        out = np.zeros((nrow, ncol), int)
        out[np.arange(nrow).repeat(sz), np.fromiter(itertools.chain.from_iterable(a), int, sz.sum())] = 1
        return out
    
    @b.add_function()
    def ts(a, nrow, ncol):
        out = np.zeros((nrow, ncol), int)
        for i, ix in enumerate(a):
            out[i][ix] = 1
        return out
    
    @b.add_function()
    def u9(a, nrow, ncol):
        out = np.zeros((nrow, ncol), int)
        for i, (x, y) in enumerate(zip(a, out)):
            y[x] = 1
            out[i] = y
        return out
    
    b.add_functions([mseifert])
    
    @b.add_arguments("number of rows/columns")
    def argument_provider():
        for n in range(2, 13):
            ncols = 2**n
            a = [
                sorted(set(np.random.randint(0, ncols, size=np.random.randint(0, ncols)))) 
                for _ in range(ncols)
            ]
            yield ncols, MultiArgument([a, ncols, ncols])
    
    r = b.run()
    r.plot()
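
    Before comparing timings, it can help to confirm that the candidates produce the same array. The sketch below re-implements the row-wise and the vectorized approach from the benchmark under hypothetical names (ones_rowwise, ones_vectorized) and also returns the intermediate index arrays of the vectorized variant, so you can see how the row and column indices are built. The random test input and its seed are assumptions for illustration only.

```python
import itertools
import numpy as np

def ones_rowwise(a, nrow, ncol):
    # One fancy-indexing assignment per row.
    out = np.zeros((nrow, ncol), dtype=int)
    for i, cols in enumerate(a):
        out[i, cols] = 1
    return out

def ones_vectorized(a, nrow, ncol):
    # Build flat row/column index arrays, then assign once.
    sizes = np.fromiter(map(len, a), dtype=int, count=nrow)
    rows = np.arange(nrow).repeat(sizes)  # row i repeated len(a[i]) times
    cols = np.fromiter(itertools.chain.from_iterable(a), dtype=int,
                       count=int(sizes.sum()))
    out = np.zeros((nrow, ncol), dtype=int)
    out[rows, cols] = 1
    return out, rows, cols

# The question's example (ncol=5 assumed from the largest index).
a = [[1, 2, 4], [0, 2, 3], [1, 3, 4], [0, 2]]
vec, rows, cols = ones_vectorized(a, len(a), 5)
print(rows)  # [0 0 0 1 1 1 2 2 2 3 3]
print(cols)  # [1 2 4 0 2 3 1 3 4 0 2]
same_small = np.array_equal(vec, ones_rowwise(a, len(a), 5))

# A random input (assumed shape and seed, illustration only).
rng = np.random.default_rng(0)
n = 32
b = [sorted(rng.choice(n, size=rng.integers(1, n), replace=False).tolist())
     for _ in range(n)]
same_random = np.array_equal(ones_vectorized(b, n, n)[0], ones_rowwise(b, n, n))
print(same_small, same_random)
```

    If either comparison were False, the timings would be measuring functions that compute different things.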
    
