Numpy shuffle multidimensional array by row only, keep column order unchanged

不思量自难忘° 2020-12-05 12:52

How can I shuffle a multidimensional array by row only in Python (that is, without shuffling the columns)?

I am looking for the most efficient solution, because my ma

6 answers
  •  谎友^ (OP)
    2020-12-05 13:08

    After some experimentation, I found that the most memory- and time-efficient way to shuffle the rows of an nd-array is to shuffle an index array and then fetch the rows in the shuffled order:

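    import numpy as np

    rand_num2 = np.random.randint(5, size=(6000, 2000))
    perm = np.arange(rand_num2.shape[0])  # row indices 0..n-1
    np.random.shuffle(perm)               # shuffle the indices, not the data
    rand_num2 = rand_num2[perm]           # reorder rows via fancy indexing (returns a copy)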
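
As a quick sanity check, the sketch below applies the same index-shuffling idea to a tiny array and verifies that rows move as whole units. It assumes NumPy's newer Generator API (`np.random.default_rng`), which the answer above does not use; the seed and the 4x3 array are illustrative choices only.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed is an arbitrary choice, for reproducibility

# Each row holds three consecutive integers, so column order is easy to verify.
data = np.arange(12).reshape(4, 3)

perm = rng.permutation(data.shape[0])  # shuffled row indices
shuffled = data[perm]                  # fancy indexing returns a new array

# Column order is unchanged: every row is still ascending by 1.
columns_intact = bool(np.all(np.diff(shuffled, axis=1) == 1))

# Only the row order changed: the same rows are present, just rearranged.
same_rows = sorted(shuffled[:, 0].tolist()) == sorted(data[:, 0].tolist())

print(shuffled)
print("columns intact:", columns_intact, "| same rows:", same_rows)
```

If the original array should be reordered in place rather than copied, `rng.shuffle(data, axis=0)` is the Generator counterpart of `np.random.shuffle(data)`.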
    

    In more detail:
    Here I use memory_profiler to measure memory usage and Python's built-in "time" module to record run time, comparing the approaches from all previous answers.

    import time

    import numpy as np
    from memory_profiler import profile


    @profile
    def main():
        # Shuffle the data itself (in place)
        rand_num = np.random.randint(5, size=(6000, 2000))
        start = time.time()
        np.random.shuffle(rand_num)
        print('Time for direct shuffle: {0}'.format((time.time() - start)))

        # Shuffle the index and get the data from the shuffled index
        rand_num2 = np.random.randint(5, size=(6000, 2000))
        start = time.time()
        perm = np.arange(rand_num2.shape[0])
        np.random.shuffle(perm)
        rand_num2 = rand_num2[perm]
        print('Time for shuffling index: {0}'.format((time.time() - start)))

        # Using np.take() with a random permutation obtained via argsort
        rand_num3 = np.random.randint(5, size=(6000, 2000))
        start = time.time()
        np.take(rand_num3, np.random.rand(rand_num3.shape[0]).argsort(), axis=0, out=rand_num3)
        print("Time taken by np.take, {0}".format((time.time() - start)))


    if __name__ == '__main__':
        main()
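
A single `time.time()` measurement can be noisy, so a repeated measurement may give a steadier comparison. The following is a minimal sketch using the standard-library `timeit` module; the helper names (`shuffle_in_place`, `shuffle_by_index`, `shuffle_with_take`, `best_time`) and the repeat counts are hypothetical choices for illustration, not part of the original benchmark.

```python
import timeit

import numpy as np


def shuffle_in_place(a):
    """Shuffle the rows of `a` directly."""
    np.random.shuffle(a)
    return a


def shuffle_by_index(a):
    """Shuffle an index array, then pick rows in that order (returns a copy)."""
    perm = np.arange(a.shape[0])
    np.random.shuffle(perm)
    return a[perm]


def shuffle_with_take(a):
    """Reorder rows with np.take, writing the result back into `a`."""
    perm = np.random.rand(a.shape[0]).argsort()
    return np.take(a, perm, axis=0, out=a)


def best_time(func, shape=(6000, 2000), repeat=3, number=5):
    """Best per-call time in seconds over `repeat` runs of `number` calls each."""
    a = np.random.randint(5, size=shape)
    runs = timeit.repeat(lambda: func(a), repeat=repeat, number=number)
    return min(runs) / number


if __name__ == "__main__":
    for func in (shuffle_in_place, shuffle_by_index, shuffle_with_take):
        print("{0}: {1:.4f} s".format(func.__name__, best_time(func)))
```

Taking the minimum over several repeats tends to reduce interference from other processes; absolute numbers will still vary with the machine and the NumPy version.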
    

    Timing results

    Time for direct shuffle: 0.03345608711242676   # 33.5 ms
    Time for shuffling index: 0.019818782806396484 # 19.8 ms
    Time taken by np.take, 0.06726956367492676     # 67.3 ms
    

    Memory profiler results

    Line #    Mem usage    Increment   Line Contents
    ================================================
        39  117.422 MiB    0.000 MiB   @profile
        40                             def main():
        41                                 # shuffle data itself
        42  208.977 MiB   91.555 MiB       rand_num = np.random.randint(5, size=(6000, 2000))
        43  208.977 MiB    0.000 MiB       start = time.time()
        44  208.977 MiB    0.000 MiB       np.random.shuffle(rand_num)
        45  208.977 MiB    0.000 MiB       print('Time for direct shuffle: {0}'.format((time.time() - start)))
        46                             
        47                                 # Shuffle index and get data from shuffled index
        48  300.531 MiB   91.555 MiB       rand_num2 = np.random.randint(5, size=(6000, 2000))
        49  300.531 MiB    0.000 MiB       start = time.time()
        50  300.535 MiB    0.004 MiB       perm = np.arange(rand_num2.shape[0])
        51  300.539 MiB    0.004 MiB       np.random.shuffle(perm)
        52  300.539 MiB    0.000 MiB       rand_num2 = rand_num2[perm]
        53  300.539 MiB    0.000 MiB       print('Time for shuffling index: {0}'.format((time.time() - start)))
        54                             
        55                                 # using np.take()
        56  392.094 MiB   91.555 MiB       rand_num3 = np.random.randint(5, size=(6000, 2000))
        57  392.094 MiB    0.000 MiB       start = time.time()
        58  392.242 MiB    0.148 MiB       np.take(rand_num3, np.random.rand(rand_num3.shape[0]).argsort(), axis=0, out=rand_num3)
        59  392.242 MiB    0.000 MiB       print("Time taken by np.take, {0}".format((time.time() - start)))
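
A line-by-line report only shows memory held after each line finishes, so a temporary copy that is freed as soon as `rand_num2` is rebound would likely not appear as an increment. To look at peak allocation instead, here is a minimal sketch using the standard-library `tracemalloc` module (NumPy reports its array allocations to it). The helper `peak_bytes` and the smaller array shape are hypothetical choices for illustration, not part of the original measurement.

```python
import tracemalloc

import numpy as np


def peak_bytes(func, *args):
    """Return the peak traced allocation, in bytes, while func(*args) runs."""
    tracemalloc.start()
    try:
        func(*args)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak


def shuffle_in_place(a):
    """Shuffle rows directly; only a small row-sized buffer is needed."""
    np.random.shuffle(a)


def shuffle_by_index(a):
    """Shuffle an index array and fancy-index; builds a full copy of `a`."""
    perm = np.arange(a.shape[0])
    np.random.shuffle(perm)
    return a[perm]


data = np.random.randint(5, size=(2000, 500))
peak_in_place = peak_bytes(shuffle_in_place, data)
peak_by_index = peak_bytes(shuffle_by_index, data)

print("array size:      {0} bytes".format(data.nbytes))
print("in-place peak:   {0} bytes".format(peak_in_place))
print("index copy peak: {0} bytes".format(peak_by_index))
```

Under these assumptions the index-based variant should peak at roughly one extra copy of the array, while the in-place shuffle stays far below that. This is worth weighing against the timing results when the array is close to the available memory.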
    
