Understanding the syntax of numpy.r_() concatenation

遇见更好的自我 2020-12-15 07:28

I read the following in the numpy documentation for the function r_:

A string integer specifies which axis to stack multiple comma separated arrays along.

3 answers
  •  春和景丽
    2020-12-15 08:05

    'n,m' tells r_ to concatenate along axis=n, and produce a shape with at least m dimensions:

    In [28]: np.r_['0,2', [1,2,3], [4,5,6]]
    Out[28]: 
    array([[1, 2, 3],
           [4, 5, 6]])
    

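    So we are concatenating along axis=0, which for two length-3 inputs would normally give a result of shape (6,). But since m=2, we are telling r_ that every input must be at least 2-dimensional: each list such as [1,2,3] is first upgraded to shape (1,3) (the new axis is added at the front), and concatenating two (1,3) arrays along axis 0 gives shape (2,3):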

    In [32]: np.r_['0,2', [1,2,3,], [4,5,6]].shape
    Out[32]: (2, 3)
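
In other words, the 'n,m' directive behaves as if each input were passed through np.array(..., ndmin=m), which prepends length-1 axes, and the results were then concatenated along axis n. A minimal sketch of that equivalence, assuming NumPy is installed; the helper name r_like is hypothetical, it only models array-like inputs (no scalars or slice notation), and it only covers the default placement of the new axes:

```python
import numpy as np

def r_like(axis, ndmin, *items):
    """Hypothetical helper mimicking np.r_['axis,ndmin', *items] for array-like inputs.

    Each item is upgraded to at least `ndmin` dimensions (np.array prepends
    length-1 axes), then all items are joined along `axis`.
    """
    upgraded = [np.array(item, ndmin=ndmin) for item in items]
    return np.concatenate(upgraded, axis=axis)

a, b = [1, 2, 3], [4, 5, 6]

# Each 1-D list of shape (3,) becomes shape (1, 3) before concatenation...
print(np.array(a, ndmin=2).shape)   # (1, 3)

# ...so joining two of them along axis 0 gives (2, 3), not (6,)
print(r_like(0, 2, a, b).shape)     # (2, 3)
print(np.array_equal(r_like(0, 2, a, b), np.r_['0,2', a, b]))  # True
```

The same helper reproduces the higher-m shapes as well, since np.array(a, ndmin=3) has shape (1, 1, 3) and np.array(a, ndmin=4) has shape (1, 1, 1, 3).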
    

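    Look at what happens when we increase m: each extra required dimension shows up as another length-1 axis in front of the original length-3 axis: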

    In [36]: np.r_['0,3', [1,2,3,], [4,5,6]].shape
    Out[36]: (2, 1, 3)    # <- 3 dimensions
    
    In [37]: np.r_['0,4', [1,2,3,], [4,5,6]].shape
    Out[37]: (2, 1, 1, 3) # <- 4 dimensions
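
Where those 1s go is itself configurable. According to the NumPy documentation for r_, the directive string also accepts an optional third integer ('n,m,k') that controls where the original axes of an upgraded input end up; the default is -1, which keeps the original shape at the end (so the 1s are prepended, as above). A short sketch comparing a few placements, assuming a reasonably recent NumPy:

```python
import numpy as np

a, b = [1, 2, 3], [4, 5, 6]

# Default (third integer -1): each input is upgraded to (1, 1, 3),
# so concatenating along axis 0 gives (2, 1, 3)
print(np.r_['0,3', a, b].shape)      # (2, 1, 3)
print(np.r_['0,3,-1', a, b].shape)   # (2, 1, 3), same as the default

# Third integer 0: the original axis moves to the front, each input
# becomes (3, 1, 1), and concatenating along axis 0 gives (6, 1, 1)
print(np.r_['0,3,0', a, b].shape)    # (6, 1, 1)

# Third integer 1: each input becomes (1, 3, 1), giving (2, 3, 1)
print(np.r_['0,3,1', a, b].shape)    # (2, 3, 1)
```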
    

    Anything you can do with r_ can also be done with one of the more readable array-building functions such as np.concatenate, np.hstack, np.vstack, np.column_stack, np.dstack or np.row_stack (an alias of np.vstack that is deprecated as of NumPy 2.0), though it may also require a call to reshape.
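
As a sketch of what that translation looks like for the inputs used above (assuming NumPy is installed; the assertions only check these particular lists, not every possible r_ call):

```python
import numpy as np

a, b = [1, 2, 3], [4, 5, 6]

# No directive: plain 1-D concatenation, same as np.hstack / np.concatenate
assert np.array_equal(np.r_[a, b], np.hstack((a, b)))

# '0,2': two length-3 rows stacked on top of each other, same as np.vstack
assert np.array_equal(np.r_['0,2', a, b], np.vstack((a, b)))

# '0,3' and '0,4': concatenate to shape (6,), then reshape explicitly
assert np.array_equal(np.r_['0,3', a, b],
                      np.concatenate((a, b)).reshape(2, 1, 3))
assert np.array_equal(np.r_['0,4', a, b],
                      np.concatenate((a, b)).reshape(2, 1, 1, 3))

print("all equivalences hold")
```

The explicit version spells out the final shape, which is arguably easier to read than decoding the '0,4' directive string.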

    Even with the extra call to reshape, those functions can be faster:

    In [38]: %timeit np.r_['0,4', [1,2,3,], [4,5,6]]
    10000 loops, best of 3: 38 us per loop
    In [43]: %timeit np.concatenate(([1,2,3,], [4,5,6])).reshape(2,1,1,3)
    100000 loops, best of 3: 10.2 us per loop
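
The %timeit figures above come from one IPython session, and absolute numbers will differ with hardware and NumPy version. A sketch for repeating the comparison outside IPython with the standard-library timeit module (the function names with_r_ and with_concatenate are hypothetical; no expected timings are given because they depend on the machine):

```python
import timeit

import numpy as np

a, b = [1, 2, 3], [4, 5, 6]

def with_r_():
    # r_ has to parse the '0,4' directive string and upgrade each input
    return np.r_['0,4', a, b]

def with_concatenate():
    # plain concatenation followed by an explicit reshape
    return np.concatenate((a, b)).reshape(2, 1, 1, 3)

# Both must build the same array before their speeds are worth comparing
assert np.array_equal(with_r_(), with_concatenate())

n = 2000
for name, fn in [("r_", with_r_), ("concatenate + reshape", with_concatenate)]:
    # best of 5 repeats, converted to microseconds per call
    best = min(timeit.repeat(fn, number=n, repeat=5)) / n
    print(f"{name}: {best * 1e6:.2f} us per call")
```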
    
