How to use numpy.argsort() as indices in more than 2 dimensions?

暗喜 2020-12-31 09:32

I know questions similar to this one have been asked many times already, but all the answers I have found only seem to work for arrays with 2 dimensions.

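For context, a common 2-D recipe for applying `np.argsort` output as indices pairs a column of row indices with the per-row sorted column indices. The sketch below is assumed to resemble the 2-D answers mentioned above; the array `a` and the names `rows` and `sorted_2d` are illustrative only.

```python
import numpy as np

a = np.array([[3, 1, 2],
              [9, 7, 8]])
idx = np.argsort(a, axis=1)  # per-row column order: [[1, 2, 0], [1, 2, 0]]

# rows has shape (2, 1); it broadcasts against idx (shape (2, 3)),
# pairing every row index with that row's sorted column indices.
rows = np.arange(a.shape[0])[:, None]
sorted_2d = a[rows, idx]
print(sorted_2d)
# [[1 2 3]
#  [7 8 9]]
```

For a 3-D or higher array, every axis before the sorted one needs its own broadcastable index array, while this recipe hard-codes exactly one (`rows`).
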
3 answers
  •  北荒 (OP)
    2020-12-31 10:28

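    Here's a vectorized implementation, `sort2` below (`sort1` is the loop-based version it is compared against). It should work for arrays with any number of dimensions, reordering along the last axis, and be quite a bit faster than what you're doing.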

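    import numpy as np
    
    
    def sort1(array, args):
        # Loop-based reference version: only handles 3-D input,
        # reordering each (i, j) row along the last axis.
        array_sort = np.zeros_like(array)
        for i in range(array.shape[0]):
            for j in range(array.shape[1]):
                array_sort[i, j] = array[i, j, args[i, j]]
    
        return array_sort
    
    
    def sort2(array, args):
        # Vectorized version for any number of dimensions, reordering
        # along the last axis. np.ix_ builds one open-mesh index array
        # per leading axis, e.g. shapes (5, 1) and (1, 6) for a (5, 6, 7) input.
        shape = array.shape
        idx = np.ix_(*(np.arange(n) for n in shape[:-1]))
        # Append a trailing length-1 axis so each index array broadcasts
        # against `args`, whose last axis holds the sorted positions.
        idx = tuple(ar[..., None] for ar in idx)
        array_sorted = array[idx + (args,)]
    
        return array_sorted
    
    
    if __name__ == '__main__':
        array = np.random.rand(5, 6, 7)
        idx = np.argsort(array)  # sorts along the last axis (default axis=-1)
    
        result1 = sort1(array, idx)
        result2 = sort2(array, idx)
    
        print(np.array_equal(result1, result2))  # True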
    
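To see why `sort2` works, it can help to print the shapes of the index arrays it builds for the (5, 6, 7) input used in the `__main__` example. This sketch repeats those two steps outside the function; `open_mesh` and `expanded` are illustrative names, and `np.broadcast_shapes` assumes NumPy 1.20 or newer.

```python
import numpy as np

shape = (5, 6, 7)  # same shape as the __main__ example

# One open-mesh index array per leading axis.
open_mesh = np.ix_(*(np.arange(n) for n in shape[:-1]))
print([ar.shape for ar in open_mesh])  # [(5, 1), (1, 6)]

# A trailing length-1 axis lets them line up with the last axis of args.
expanded = tuple(ar[..., None] for ar in open_mesh)
print([ar.shape for ar in expanded])  # [(5, 1, 1), (1, 6, 1)]

# All index arrays, including args of shape (5, 6, 7), broadcast together,
# so array[expanded + (args,)] picks one element per output position.
print(np.broadcast_shapes(*(ar.shape for ar in expanded), shape))  # (5, 6, 7)
```
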

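On NumPy 1.15 or newer, `np.take_along_axis` performs the same fancy-indexing step along any axis, so the helper can shrink to a single call. A minimal sketch, where the wrapper name `sort_by_args` is hypothetical and `np.random.default_rng` assumes NumPy 1.17 or newer; it checks the result against `np.sort` on a 4-D array:

```python
import numpy as np


def sort_by_args(array, args, axis=-1):
    """Reorder `array` along `axis` using indices from np.argsort(array, axis=axis)."""
    # take_along_axis pairs each index in `args` with the matching
    # positions on every other axis, which is what sort2 builds by hand.
    return np.take_along_axis(array, args, axis=axis)


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    array = rng.random((3, 4, 5, 6))  # 4-D input

    for axis in range(array.ndim):
        idx = np.argsort(array, axis=axis)
        result = sort_by_args(array, idx, axis=axis)
        # Gathering with argsort indices must reproduce np.sort on that axis.
        print(axis, np.array_equal(result, np.sort(array, axis=axis)))  # True
```

Unlike `sort2`, this version is not limited to the last axis.
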