How to pad with zeros a tensor along some axis (Python)

陌清茗 2020-12-24 05:27

I would like to pad a NumPy tensor with zeros along a chosen axis. For instance, I have a tensor r with shape (4,3,2), but I am only interested in padd…

2 answers
  • 2020-12-24 06:08

    You can use np.pad():

    import numpy as np
    
    a = np.ones((4, 3, 2))
    
    # npad is a tuple of (n_before, n_after) for each dimension
    npad = ((0, 0), (1, 2), (2, 1))
    b = np.pad(a, pad_width=npad, mode='constant', constant_values=0)
    
    print(b.shape)
    # (4, 6, 5)
    
    print(b)
    # [[[ 0.  0.  0.  0.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  0.  0.  0.]
    #   [ 0.  0.  0.  0.  0.]]
    
    #  [[ 0.  0.  0.  0.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  0.  0.  0.]
    #   [ 0.  0.  0.  0.  0.]]
    
    #  [[ 0.  0.  0.  0.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  0.  0.  0.]
    #   [ 0.  0.  0.  0.  0.]]
    
    #  [[ 0.  0.  0.  0.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  1.  1.  0.]
    #   [ 0.  0.  0.  0.  0.]
    #   [ 0.  0.  0.  0.  0.]]]
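
    The question only needs one axis padded. With np.pad, that means giving (0, 0) for every other axis. A minimal sketch, assuming the (4, 3, 2) shape from the question and, arbitrarily, two zeros appended along axis 1:

```python
import numpy as np

r = np.ones((4, 3, 2))

# (n_before, n_after) for each axis; only axis 1 is padded, at its end
npad = ((0, 0), (0, 2), (0, 0))
padded = np.pad(r, pad_width=npad, mode='constant', constant_values=0)

print(padded.shape)  # (4, 5, 2)
```

    In NumPy 1.17 and later, mode='constant' with a fill value of 0 is the default, so np.pad(r, npad) gives the same result.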
    
  • 2020-12-24 06:24

    This function pads at the end of a given axis.
    If you want to pad both sides, modify it accordingly.

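    import numpy as np
    
    def pad_along_axis(array: np.ndarray, target_length: int, axis: int = 0):
        # Number of zeros needed to reach target_length along `axis`
        pad_size = target_length - array.shape[axis]
    
        # Already long enough: return the input unchanged
        if pad_size <= 0:
            return array
    
        # (before, after) for every axis; only `axis` is padded, at its end
        npad = [(0, 0)] * array.ndim
        npad[axis] = (0, pad_size)
    
        return np.pad(array, pad_width=npad, mode='constant', constant_values=0)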
    

    Example:

    >>> a = np.identity(5)
    >>> b = pad_along_axis(a, 7, axis=1)
    >>> print(a, a.shape)
    [[1. 0. 0. 0. 0.]
     [0. 1. 0. 0. 0.]
     [0. 0. 1. 0. 0.]
     [0. 0. 0. 1. 0.]
     [0. 0. 0. 0. 1.]] (5, 5)
    
    >>> print(b, b.shape)
    [[1. 0. 0. 0. 0. 0. 0.]
     [0. 1. 0. 0. 0. 0. 0.]
     [0. 0. 1. 0. 0. 0. 0.]
     [0. 0. 0. 1. 0. 0. 0.]
     [0. 0. 0. 0. 1. 0. 0.]] (5, 7)
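
    The answer says the function can be modified to pad both sides. One possible modification splits pad_size between the start and the end of the axis. The name pad_along_axis_both and the choice to put any odd extra zero at the end are assumptions for illustration, not part of the original answer:

```python
import numpy as np

def pad_along_axis_both(array: np.ndarray, target_length: int, axis: int = 0) -> np.ndarray:
    """Hypothetical helper: zero-pad `array` along `axis` to `target_length`,
    splitting the padding between both ends (an odd extra zero goes at the end)."""
    pad_size = target_length - array.shape[axis]
    if pad_size <= 0:
        # Already long enough: return the input unchanged
        return array

    before = pad_size // 2
    after = pad_size - before

    # (before, after) for every axis; only `axis` gets non-zero padding
    npad = [(0, 0)] * array.ndim
    npad[axis] = (before, after)
    return np.pad(array, pad_width=npad, mode='constant', constant_values=0)


a = np.identity(5)
b = pad_along_axis_both(a, 8, axis=1)
print(b.shape)  # (5, 8)
print(b[0])     # [0. 1. 0. 0. 0. 0. 0. 0.]
```

    With 3 zeros to add, one goes before the data and two go after it.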
    