Numpy: Index 3D array with index of last axis stored in 2D array

北海茫月 2020-12-19 00:18

I have an ndarray of shape (z,y,x) containing values. I am trying to index this array with another ndarray of shape (y,x) th

4 Answers
  •  轻奢々 (original poster)
    2020-12-19 00:36

    For readability, np.choose definitely looks great.
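
    As a minimal sketch of what that looks like, assuming a small made-up array (the values and shapes below are illustrative only and not from the question), np.choose picks val_arr[z_indices[y, x], y, x] for every (y, x):

```python
import numpy as np

# Hypothetical toy data: 3 z-slices, each of shape (2, 4)
val_arr = np.arange(3 * 2 * 4).reshape(3, 2, 4)
z_indices = np.array([[0, 1, 2, 0],
                      [2, 2, 1, 0]])

# The first axis of val_arr is treated as the sequence of choices;
# z_indices selects which slice supplies each (y, x) element.
picked = np.choose(z_indices, val_arr)  # same as z_indices.choose(val_arr)
print(picked)
```

    Note that np.choose caps the number of choice arrays, so this form only suits a short z axis.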

    If performance is of the essence, you can compute the linear indices and then use np.take, or index a flattened version obtained with .ravel(), to extract those specific elements from val_arr. The implementation would look something like this:

    import numpy as np

    def linidx_take(val_arr, z_indices):

        # Get number of rows (y) and columns (x) in values array
        _, nR, nC = val_arr.shape

        # Linear index of element (z, y, x) in the flattened array is
        # z*nR*nC + y*nC + x; use it to extract elements with np.take
        idx = nR*nC*z_indices + nC*np.arange(nR)[:,None] + np.arange(nC)
        return np.take(val_arr, idx)  # Or val_arr.ravel()[idx]
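
    The linear-index arithmetic is easiest to check on a non-square array, where the row and column counts differ and any mix-up between them would show. The sketch below is an illustration under assumed sizes; the helper name linear_index_lookup and the random test data are hypothetical, and the result is compared against a plain Python loop:

```python
import numpy as np

def linear_index_lookup(val_arr, z_indices):
    # Hypothetical helper name, written for a (z, y, x) array
    # whose y and x sizes may differ.
    _, ny, nx = val_arr.shape
    # Flat C-order offset of element (z, y, x): z*ny*nx + y*nx + x
    flat_idx = ny * nx * z_indices + nx * np.arange(ny)[:, None] + np.arange(nx)
    return np.take(val_arr, flat_idx)

rng = np.random.default_rng(0)
val_arr = rng.random((5, 3, 4))            # z=5, y=3, x=4 (non-square)
z_indices = rng.integers(0, 5, (3, 4))

# Reference result from an explicit loop over every (y, x)
expected = np.empty((3, 4))
for y in range(3):
    for x in range(4):
        expected[y, x] = val_arr[z_indices[y, x], y, x]

result = linear_index_lookup(val_arr, z_indices)
print(np.array_equal(result, expected))
```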
    

    Runtime tests and result verification:

    The ogrid-based solution is made into a generic version for these tests, like so:

    In [182]: def ogrid_based(val_arr,z_indices):
         ...:   v_shp = val_arr.shape
         ...:   y,x = np.ogrid[0:v_shp[1], 0:v_shp[2]]
         ...:   return val_arr[z_indices, y, x]
         ...: 
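
    To see why the ogrid version works, it may help to look at the shapes involved. In this small sketch (toy data, assumed for illustration only), np.ogrid returns a column of y indices and a row of x indices, which broadcast against the (y, x)-shaped z_indices:

```python
import numpy as np

# Hypothetical toy data: 3 z-slices, each of shape (2, 4)
val_arr = np.arange(3 * 2 * 4).reshape(3, 2, 4)
z_indices = np.array([[0, 1, 2, 0],
                      [2, 2, 1, 0]])

# Open grid: y is a (2, 1) column, x is a (1, 4) row
y, x = np.ogrid[0:2, 0:4]
print(y.shape, x.shape)

# The three index arrays broadcast to shape (2, 4), so each output
# element is val_arr[z_indices[y, x], y, x]
picked = val_arr[z_indices, y, x]
print(picked)
```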
    

    Case #1: Smaller data size

    In [183]: val_arr = np.random.rand(30,30,30)
         ...: z_indices = np.random.randint(0,30,(30,30))
         ...: 
    
    In [184]: np.allclose(z_indices.choose(val_arr),ogrid_based(val_arr,z_indices))
    Out[184]: True
    
    In [185]: np.allclose(z_indices.choose(val_arr),linidx_take(val_arr,z_indices))
    Out[185]: True
    
    In [187]: %timeit z_indices.choose(val_arr)
    1000 loops, best of 3: 230 µs per loop
    
    In [188]: %timeit ogrid_based(val_arr,z_indices)
    10000 loops, best of 3: 54.1 µs per loop
    
    In [189]: %timeit linidx_take(val_arr,z_indices)
    10000 loops, best of 3: 30.3 µs per loop
    

    Case #2: Bigger data size

    In [191]: val_arr = np.random.rand(300,300,300)
         ...: z_indices = np.random.randint(0,300,(300,300))
         ...: 
    
    In [192]: z_indices.choose(val_arr) # Seems like there is some limitation here with bigger arrays.
    Traceback (most recent call last):
    
      File "", line 1, in 
        z_indices.choose(val_arr)
    
    ValueError: Need between 2 and (32) array objects (inclusive).
    
    
    In [194]: np.allclose(linidx_take(val_arr,z_indices),ogrid_based(val_arr,z_indices))
    Out[194]: True
    
    In [195]: %timeit ogrid_based(val_arr,z_indices)
    100 loops, best of 3: 3.67 ms per loop
    
    In [196]: %timeit linidx_take(val_arr,z_indices)
    100 loops, best of 3: 2.04 ms per loop
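
    For anyone without IPython's %timeit, a comparison along these lines could be run as a plain script with the standard timeit module. This is only a sketch: the function names and array sizes are placeholders chosen for illustration, and the timings will differ by machine and NumPy version.

```python
import timeit
import numpy as np

def ogrid_lookup(val_arr, z_indices):
    # Broadcast open-grid y/x indices against the (y, x) z-index array
    _, ny, nx = val_arr.shape
    y, x = np.ogrid[0:ny, 0:nx]
    return val_arr[z_indices, y, x]

def flat_lookup(val_arr, z_indices):
    # Flat C-order offset of (z, y, x): z*ny*nx + y*nx + x
    _, ny, nx = val_arr.shape
    idx = ny * nx * z_indices + nx * np.arange(ny)[:, None] + np.arange(nx)
    return np.take(val_arr, idx)

# Sizes are arbitrary placeholders; adjust them to match your data
nz, ny, nx = 100, 120, 140
val_arr = np.random.rand(nz, ny, nx)
z_indices = np.random.randint(0, nz, (ny, nx))

# Verify that both approaches agree before timing them
same = np.array_equal(ogrid_lookup(val_arr, z_indices),
                      flat_lookup(val_arr, z_indices))
print("results match:", same)

for fn in (ogrid_lookup, flat_lookup):
    seconds = timeit.timeit(lambda: fn(val_arr, z_indices), number=100)
    print(f"{fn.__name__}: {seconds / 100 * 1e6:.1f} µs per call")
```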
    
