I have an ndarray of shape (z, y, x) containing values. I am trying to index this array with another ndarray of shape (y, x) th…
For readability, np.choose definitely looks great.
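As a sketch of what `np.choose` computes here (assuming `z_indices` has shape (y, x) and every entry lies in `range(z)`), each output element `out[i, j]` is `val_arr[z_indices[i, j], i, j]`. The explicit loop below spells that out on small random data. The array sizes and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
val_arr = rng.random((4, 3, 5))               # shape (z, y, x)
z_indices = rng.integers(0, 4, size=(3, 5))   # shape (y, x), values in range(z)

# np.choose treats val_arr as a sequence of z choice arrays, each of shape (y, x)
chosen = np.choose(z_indices, val_arr)

# The same selection written as an explicit double loop
looped = np.empty(z_indices.shape, dtype=val_arr.dtype)
for i in range(z_indices.shape[0]):
    for j in range(z_indices.shape[1]):
        looped[i, j] = val_arr[z_indices[i, j], i, j]

print(np.array_equal(chosen, looped))  # True
```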
If performance is essential, you can calculate the linear indices and then either use np.take or extract those specific elements from a flattened version of val_arr obtained with .ravel(). The implementation would look something like this:
```python
import numpy as np

def linidx_take(val_arr, z_indices):
    # Get number of rows (y) and columns (x) in the values array
    _, nR, nC = val_arr.shape
    # Get linear indices (z*nR*nC + row*nC + col) and extract elements with np.take
    idx = nR*nC*z_indices + nC*np.arange(nR)[:, None] + np.arange(nC)
    return np.take(val_arr, idx)  # Or val_arr.ravel()[idx]
```
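For a C-ordered array of shape `(Z, R, C)`, the linear index of `val_arr[z, r, c]` is `z*R*C + r*C + c`. A minimal sketch that builds these indices by broadcasting and checks them against NumPy's own `np.ravel_multi_index`, on a deliberately non-square `(y, x)` plane so that swapping the row and column counts would show up as a mismatch. The names `Z`, `R`, `C` and the sizes are illustrative.

```python
import numpy as np

Z, R, C = 4, 3, 5                             # illustrative sizes; R != C on purpose
rng = np.random.default_rng(1)
val_arr = rng.random((Z, R, C))               # shape (z, y, x)
z_indices = rng.integers(0, Z, size=(R, C))   # shape (y, x), values in range(Z)

# Hand-computed C-order linear indices: z*R*C + r*C + c, built by broadcasting
idx = R * C * z_indices + C * np.arange(R)[:, None] + np.arange(C)

# Reference: full (R, C) coordinate grids fed to np.ravel_multi_index
rr, cc = np.indices((R, C))
ref = np.ravel_multi_index((z_indices, rr, cc), val_arr.shape)
print(np.array_equal(idx, ref))  # True

# Gathering with the flat indices matches fancy indexing on the 3-D array
print(np.array_equal(np.take(val_arr, idx), val_arr[z_indices, rr, cc]))  # True
```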
Runtime tests and result verification:

The ogrid-based solution from here is made into a generic version for these tests, like so:
```
In [182]: def ogrid_based(val_arr,z_indices):
     ...:     v_shp = val_arr.shape
     ...:     y,x = np.ogrid[0:v_shp[1], 0:v_shp[2]]
     ...:     return val_arr[z_indices, y, x]
     ...:
```
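`np.ogrid` returns open (non-broadcast) coordinate arrays: for a `(y, x)` plane it gives a column of shape `(y, 1)` and a row of shape `(1, x)`. Indexing with `[z_indices, y, x]` broadcasts all three index arrays to `(y, x)`, so each output element is `val_arr[z_indices[i, j], i, j]`. A small sketch with hand-picked values (sizes are arbitrary) that makes the shapes and the result easy to check by eye, since here `val_arr[z, i, j] == z*12 + i*4 + j`:

```python
import numpy as np

val_arr = np.arange(2 * 3 * 4).reshape(2, 3, 4)   # shape (z, y, x) = (2, 3, 4)
z_indices = np.array([[0, 1, 0, 1],
                      [1, 1, 0, 0],
                      [0, 0, 1, 1]])               # shape (y, x) = (3, 4)

y, x = np.ogrid[0:3, 0:4]
print(y.shape, x.shape)   # (3, 1) (1, 4)

# The three index arrays broadcast together to shape (3, 4)
out = val_arr[z_indices, y, x]
print(out)
# [[ 0 13  2 15]
#  [16 17  6  7]
#  [ 8  9 22 23]]
```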
Case #1: Smaller data size
```
In [183]: val_arr = np.random.rand(30,30,30)
     ...: z_indices = np.random.randint(0,30,(30,30))
     ...:

In [184]: np.allclose(z_indices.choose(val_arr),ogrid_based(val_arr,z_indices))
Out[184]: True

In [185]: np.allclose(z_indices.choose(val_arr),linidx_take(val_arr,z_indices))
Out[185]: True

In [187]: %timeit z_indices.choose(val_arr)
1000 loops, best of 3: 230 µs per loop

In [188]: %timeit ogrid_based(val_arr,z_indices)
10000 loops, best of 3: 54.1 µs per loop

In [189]: %timeit linidx_take(val_arr,z_indices)
10000 loops, best of 3: 30.3 µs per loop
```
Case #2: Bigger data size
```
In [191]: val_arr = np.random.rand(300,300,300)
     ...: z_indices = np.random.randint(0,300,(300,300))
     ...:

In [192]: z_indices.choose(val_arr) # np.choose accepts at most 32 choice arrays, but val_arr has 300
Traceback (most recent call last):
  File "", line 1, in
    z_indices.choose(val_arr)
ValueError: Need between 2 and (32) array objects (inclusive).

In [194]: np.allclose(linidx_take(val_arr,z_indices),ogrid_based(val_arr,z_indices))
Out[194]: True

In [195]: %timeit ogrid_based(val_arr,z_indices)
100 loops, best of 3: 3.67 ms per loop

In [196]: %timeit linidx_take(val_arr,z_indices)
100 loops, best of 3: 2.04 ms per loop
```
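The `%timeit` lines are IPython magics. Outside IPython, the standard-library `timeit` module gives comparable measurements. A minimal sketch, with both functions repeated so the snippet runs on its own (`linidx_take` is written here for a general `(z, y, x)` shape, including a non-square `(y, x)` plane). The array sizes and repeat counts are arbitrary assumptions, and the printed timings will depend on the machine and NumPy version.

```python
import timeit
import numpy as np

def ogrid_based(val_arr, z_indices):
    # Open grids of row/column coordinates, broadcast against z_indices
    _, ny, nx = val_arr.shape
    y, x = np.ogrid[0:ny, 0:nx]
    return val_arr[z_indices, y, x]

def linidx_take(val_arr, z_indices):
    # C-order linear index of val_arr[z, r, c] is z*ny*nx + r*nx + c
    _, ny, nx = val_arr.shape
    idx = ny * nx * z_indices + nx * np.arange(ny)[:, None] + np.arange(nx)
    return np.take(val_arr, idx)

rng = np.random.default_rng(2)
val_arr = rng.random((100, 120, 140))              # arbitrary, non-square (y, x) plane
z_indices = rng.integers(0, 100, size=(120, 140))

# Check that both approaches agree before timing them
assert np.array_equal(ogrid_based(val_arr, z_indices), linidx_take(val_arr, z_indices))

for func in (ogrid_based, linidx_take):
    # Best of 3 repeats of 200 calls each, reported per call
    best = min(timeit.repeat(lambda: func(val_arr, z_indices), number=200, repeat=3))
    print(f"{func.__name__}: {best / 200 * 1e6:.1f} µs per call")
```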