Numpy: argmax over multiple axes without loop

Submitted by 点点圈 on 2020-06-07 04:33:14

Question


I have an N-dimensional array (named A). For each row along the first axis of A (i.e. each A[n]), I want to obtain the coordinates of the maximum value along the other axes of A. I would then like to return a 2-dimensional array with the coordinates of the maximum value for each row along the first axis of A.

I already solved my problem using a loop, but I was wondering whether there is a more efficient way of doing this. My current solution (for an example array A) is as follows:

import numpy as np

A=np.reshape(np.concatenate((np.arange(0,12),np.arange(0,-4,-1))),(4,2,2))
maxpos=np.empty(shape=(4,2))
for n in range(0, 4):
    maxpos[n,:]=np.unravel_index(np.argmax(A[n,:,:]), A[n,:,:].shape)

Here, we would have:

A: 
[[[ 0  1]
  [ 2  3]]

 [[ 4  5]
  [ 6  7]]

 [[ 8  9]
  [10 11]]

 [[ 0 -1]
  [-2 -3]]]

maxpos:
[[ 1.  1.]
 [ 1.  1.]
 [ 1.  1.]
 [ 0.  0.]]

If there are multiple maximizers, I don't mind which is chosen.
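For reference, the `np.argmax` documentation states that when the maximum occurs more than once, the index of the first occurrence is returned. Because each block is searched in flattened C (row-major) order, the first maximizer in row-major order is the one reported. A small illustration, using a made-up 2x2 block `B`:

```python
import numpy as np

# Made-up block in which the maximum value 5 appears twice
B = np.array([[1, 5],
              [5, 2]])

flat = np.argmax(B)                        # 1: the first 5 in the flattened [1, 5, 5, 2]
row, col = np.unravel_index(flat, B.shape) # row 0, col 1
```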

I have tried to use np.apply_over_axes, but I haven't managed to make it return the outcome I want.


Answer 1:


You could do something like this:

# Reshape the input to a 2D array, keeping the first axis as rows and flattening all other axes.
# Then get the index of the max value within each row.
max_idx = A.reshape(A.shape[0],-1).argmax(1)

# Convert those flat indices into coordinates over the trailing axes of A (the shape of each A[n])
maxpos_vect = np.column_stack(np.unravel_index(max_idx, A.shape[1:]))
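A self-contained sketch of the same reshape + `np.unravel_index` idea, wrapped in a helper. The name `argmax_over_trailing_axes` is hypothetical (it is not a NumPy function), and the sketch assumes `A` has at least two dimensions and a non-empty first axis (with `A.shape[0] == 0`, the `-1` in `reshape` cannot be inferred).

```python
import numpy as np

def argmax_over_trailing_axes(A):
    """Return an (A.shape[0], A.ndim - 1) array holding, for each A[n],
    the coordinates of its maximum within the trailing axes."""
    # Flatten all trailing axes so that row n of `flat` holds the values of A[n]
    flat = A.reshape(A.shape[0], -1)
    # Flat (row-major) index of the maximum within each row
    max_idx = flat.argmax(axis=1)
    # Convert the flat indices back into coordinates over A.shape[1:]
    coords = np.unravel_index(max_idx, A.shape[1:])
    # unravel_index returns one index array per trailing axis; stack them as columns
    return np.column_stack(coords)

# The example array from the question
A = np.reshape(np.concatenate((np.arange(0, 12), np.arange(0, -4, -1))), (4, 2, 2))
print(argmax_over_trailing_axes(A))
# [[1 1]
#  [1 1]
#  [1 1]
#  [0 0]]
```

Unlike the loop in the question, this returns integer coordinates directly, so no `astype(int)` is needed before using them as indices.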

Sample run:

In [214]: # Input array
     ...: A = np.random.rand(5,4,3,7,8)

In [215]: # Setup output array and use original loopy code
     ...: maxpos=np.empty(shape=(5,4)) # 4 columns: one coordinate per trailing axis of the 5-D array A
     ...: for n in range(0, 5):
     ...:     maxpos[n,:]=np.unravel_index(np.argmax(A[n,:,:,:,:]), A[n,:,:,:,:].shape)
     ...:     

In [216]: # Proposed approach
     ...: max_idx = A.reshape(A.shape[0],-1).argmax(1)
     ...: maxpos_vect = np.column_stack(np.unravel_index(max_idx, A[0,:,:].shape))
     ...: 

In [219]: # Verify results
     ...: np.array_equal(maxpos.astype(int),maxpos_vect)
Out[219]: True



Answer 2:


You can use a list comprehension:

result = [np.unravel_index(np.argmax(r), r.shape) for r in A]

It is, IMO, more readable, but it will not be much faster than an explicit loop.
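Applied to the example `A` from the question, the comprehension yields a list of coordinate tuples; wrapping it in `np.array` turns it into the same (N, ndim - 1) array the loop builds, but with integer rather than float entries. A minimal sketch:

```python
import numpy as np

# The example array from the question
A = np.reshape(np.concatenate((np.arange(0, 12), np.arange(0, -4, -1))), (4, 2, 2))

# Iterating over A yields A[0], A[1], ..., each of shape A.shape[1:]
coords = [np.unravel_index(np.argmax(r), r.shape) for r in A]

# Each element is a tuple of integer scalars; np.array stacks them into a (4, 2) int array
maxpos = np.array(coords)
print(maxpos)
# [[1 1]
#  [1 1]
#  [1 1]
#  [0 0]]
```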

The fact that the outer loop runs in Python should matter only if the first dimension is the very large one.

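If this is the case (e.g. you have ten million 2x2 matrices), then flipping the loop, i.e. comparing the four entries across the whole first axis at once instead of looping over the matrices, is faster: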

import numpy as np

# data is assumed to be an (N, 2, 2) array, e.g. data = A from the question

# true if 0,0 is not smaller than others
m00 = ((data[:,0,0] >= data[:,1,0]) &
       (data[:,0,0] >= data[:,0,1]) &
       (data[:,0,0] >= data[:,1,1]))

# true if 0,1 is not smaller than others
m01 = ((data[:,0,1] >= data[:,1,0]) &
       (data[:,0,1] >= data[:,0,0]) &
       (data[:,0,1] >= data[:,1,1]))

# true if 1,0 is not smaller than others
m10 = ((data[:,1,0] >= data[:,0,0]) &
       (data[:,1,0] >= data[:,0,1]) &
       (data[:,1,0] >= data[:,1,1]))

# true if 1,1 is not smaller than others
m11 = ((data[:,1,1] >= data[:,1,0]) &
       (data[:,1,1] >= data[:,0,1]) &
       (data[:,1,1] >= data[:,0,0]))

# on ties, keep only the first maximizer in row-major order (00, 01, 10, 11), as np.argmax does
m01 &= ~m00
m10 &= ~(m00|m01)
m11 &= ~(m00|m01|m10)

# compute result
result = np.zeros((len(data), 2), np.int32)
result[:,1] |= m01|m11
result[:,0] |= m10|m11

On my machine, the code above is about 50 times faster (for one million 2x2 matrices).
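The speedup figure above is the answerer's own measurement. Below is a sketch for checking the three approaches against each other, and timing them, on your own machine. The helper names (`argmax_loop`, `argmax_reshape`, `argmax_2x2_masks`) are hypothetical, the mask version is a condensed restatement of the code above, and the array size and repeat count are arbitrary choices. Small random integers are used so that ties actually occur and the tie-breaking gets exercised.

```python
import timeit
import numpy as np

def argmax_loop(data):
    # Python-level loop over the first axis (the list-comprehension approach)
    return np.array([np.unravel_index(np.argmax(r), r.shape) for r in data])

def argmax_reshape(data):
    # Flatten the trailing axes, argmax per row, then unravel (Answer 1)
    flat_idx = data.reshape(len(data), -1).argmax(axis=1)
    return np.column_stack(np.unravel_index(flat_idx, data.shape[1:]))

def argmax_2x2_masks(data):
    # Comparison-mask approach for an (N, 2, 2) array (condensed from the code above)
    a, b = data[:, 0, 0], data[:, 0, 1]
    c, d = data[:, 1, 0], data[:, 1, 1]
    m00 = (a >= b) & (a >= c) & (a >= d)
    m01 = (b >= a) & (b >= c) & (b >= d)
    m10 = (c >= a) & (c >= b) & (c >= d)
    m11 = (d >= a) & (d >= b) & (d >= c)
    # Keep only the first maximizer in row-major order on ties
    m01 &= ~m00
    m10 &= ~(m00 | m01)
    m11 &= ~(m00 | m01 | m10)
    result = np.zeros((len(data), 2), np.int32)
    result[:, 1] |= m01 | m11   # column index is 1 for (0,1) and (1,1)
    result[:, 0] |= m10 | m11   # row index is 1 for (1,0) and (1,1)
    return result

rng = np.random.default_rng(0)
# Values in {0, 1, 2}, so many matrices contain tied maxima; increase N for realistic timings
data = rng.integers(0, 3, size=(20_000, 2, 2))

# All three should agree, including on ties
ref = argmax_loop(data)
assert np.array_equal(argmax_reshape(data), ref)
assert np.array_equal(argmax_2x2_masks(data), ref)

for name, fn in [("loop", argmax_loop), ("reshape", argmax_reshape), ("masks", argmax_2x2_masks)]:
    seconds = timeit.timeit(lambda: fn(data), number=3) / 3
    print(f"{name:8s} {seconds * 1e3:.2f} ms per call")
```

No timings are shown because they depend on the machine, the NumPy version, and N; the point of the sketch is the relative comparison and the equality checks.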



Source: https://stackoverflow.com/questions/30589211/numpy-argmax-over-multiple-axes-without-loop
