Shift rows of a numpy array independently

暖寄归人 2020-12-06 15:04

This is an extension of the question Roll rows of a matrix independently (quoted below):

I have a matrix (2d numpy ndarray, to be precise):

A = np.array([[4, 0,         


        
2 Answers
  • 2020-12-06 15:27

    Inspired by the solution to Roll rows of a matrix independently, here's a vectorized one based on np.lib.stride_tricks.as_strided (through scikit-image's view_as_windows, which is built on it):

    import numpy as np
    from skimage.util.shape import view_as_windows as viewW
    
    def strided_indexing_roll(a, r):
        # Pad each row with n-1 NaNs on both sides so that every shift
        # in [-(n-1), n-1] has a full-length window (requires |r| < n)
        p = np.full((a.shape[0],a.shape[1]-1),np.nan)
        a_ext = np.concatenate((p,a,p),axis=1)
    
        # Get sliding windows, shape (m, 2n-1, 1, n); use advanced indexing to
        # pick, for row i, the window starting at column (n-1) - r[i]
        n = a.shape[1]
        return viewW(a_ext,(1,n))[np.arange(len(r)), -r + (n-1),0]
    

    Sample run -

    In [76]: a
    Out[76]: 
    array([[4, 0, 0],
           [1, 2, 3],
           [0, 0, 5]])
    
    In [77]: r
    Out[77]: array([ 2,  0, -1])
    
    In [78]: strided_indexing_roll(a, r)
    Out[78]: 
    array([[nan, nan,  4.],
           [ 1.,  2.,  3.],
           [ 0.,  5., nan]])
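On NumPy 1.20 or newer, the same padding + sliding-window + advanced-indexing trick can be sketched without the scikit-image dependency by using np.lib.stride_tricks.sliding_window_view. The function name shift_rows and its fill parameter are hypothetical. As an assumption beyond the answer above, this sketch pads each side with n fill values instead of n-1 and clips the shifts to [-n, n], so a shift of a full row width or more gives an all-fill row rather than silently picking the wrong window or raising an IndexError.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def shift_rows(a, r, fill=np.nan):
    """Shift row i of a right by r[i] (left if r[i] < 0), filling vacated cells.

    Hypothetical helper: same padding + sliding-window + advanced-indexing idea
    as strided_indexing_roll, but built on NumPy's sliding_window_view (>= 1.20).
    """
    a = np.asarray(a, dtype=float)
    r = np.asarray(r)
    m, n = a.shape
    # A shift of n or more empties the whole row, so clip shifts to [-n, n].
    r = np.clip(r, -n, n)
    # Pad n fill values on each side; every clipped shift now has a full window.
    pad = np.full((m, n), fill)
    a_ext = np.concatenate((pad, a, pad), axis=1)  # shape (m, 3n)
    # windows[i, k] is a read-only view of a_ext[i, k:k + n]; shape (m, 2n + 1, n).
    windows = sliding_window_view(a_ext, n, axis=1)
    # Output column j of row i must hold a[i, j - r[i]], i.e. a_ext[i, n - r[i] + j],
    # so pick the window that starts at column n - r[i]. Advanced indexing copies.
    return windows[np.arange(m), n - r]


a = np.array([[4, 0, 0],
              [1, 2, 3],
              [0, 0, 5]])
r = np.array([2, 0, -1])
print(shift_rows(a, r))
```

For the sample a and r, this yields the same NaN-padded values as the sample run above. Because the final step uses advanced indexing, the returned array is a writable copy even though the windows themselves are a read-only view.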
    
  • 2020-12-06 15:38

    I was able to hack this together with linear indexing. It gets the right result, but performs rather slowly on large arrays.

    import numpy as np
    
    A = np.array([[4, 0, 0],
                  [1, 2, 3],
                  [0, 0, 5]]).astype(float)
    
    r = np.array([2, 0, -1])
    
    rows, column_indices = np.ogrid[:A.shape[0], :A.shape[1]]
    
    # Convert negative shifts to their positive equivalents so that
    # column_indices - r never drops below -(n-1); NumPy's negative indexing
    # then wraps around like np.roll (a modulo operation would also work)
    r_old = r.copy()
    r[r < 0] += A.shape[1]
    column_indices = column_indices - r[:,np.newaxis]
    
    result = A[rows, column_indices]
    
    # Replace the wrapped-around entries with NaNs
    row_length = result.shape[-1]
    
    pad_inds = []
    for ind,i in enumerate(r_old):
        if i > 0:
            # Right shift: the first i entries of this row wrapped around
            inds2pad = [np.ravel_multi_index((ind,) + (j,),result.shape) for j in range(i)]
            pad_inds.extend(inds2pad)
        if i < 0:
            # Left shift: the last -i entries of this row wrapped around
            inds2pad = [np.ravel_multi_index((ind,) + (j,),result.shape) for j in range(row_length+i,row_length)]
            pad_inds.extend(inds2pad)
    # .flat always writes into result (ravel() may return a copy)
    result.flat[pad_inds] = np.nan
    

    Gives the expected result:

    print(result)
    
    [[ nan  nan   4.]
     [  1.   2.   3.]
     [  0.   5.  nan]]
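The slow part of this answer is the Python loop that builds the list of NaN positions one index at a time. A minimal vectorized sketch of the same gather-then-blank idea (the function name shift_rows_masked is hypothetical) replaces that loop with a boolean mask: an output cell's source column, col - r[i], either lies inside the row or it does not, and only the cells whose source lies outside were wrapped around by the gather. It also uses the modulo operation that the answer's comment mentions in place of the negative-shift adjustment.

```python
import numpy as np


def shift_rows_masked(A, r):
    """Shift row i of A right by r[i] (left if r[i] < 0), filling with NaN.

    Hypothetical vectorized variant of the linear-indexing answer: gather a roll
    with modulo indexing, then blank out the cells that wrapped around.
    """
    A = np.asarray(A, dtype=float)
    r = np.asarray(r)
    n_rows, n_cols = A.shape
    rows, cols = np.ogrid[:n_rows, :n_cols]
    # Source column of each output cell, before any wrapping; shape (n_rows, n_cols).
    src = cols - r[:, np.newaxis]
    # Modulo keeps every index in [0, n_cols), so this gather is a per-row np.roll.
    result = A[rows, src % n_cols]
    # Cells whose source lies outside the row were wrapped in by the roll.
    result[(src < 0) | (src >= n_cols)] = np.nan
    return result


A = np.array([[4, 0, 0],
              [1, 2, 3],
              [0, 0, 5]])
r = np.array([2, 0, -1])
print(shift_rows_masked(A, r))
```

Because the mask is computed from src rather than from the wrapped indices, shifts of a full row width or more come out as all-NaN rows instead of raising an error during indexing.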
    