If I have an ndarray like this:
>>> a = np.arange(27).reshape(3,3,3)
>>> a
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
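I use a small helper, index_at(), to build the full index tuple that applies a per-position index array (such as the result of argmax) along a chosen axis: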
import numpy as np
def index_at(idx, shape, axis=-1):
    """Build an index tuple that applies ``idx`` along one axis of an array.

    For an array ``a`` with ``a.shape == shape``, ``a[index_at(idx, shape,
    axis)]`` picks, for every position over the remaining axes, the element
    whose coordinate along ``axis`` is given by ``idx`` at that position.

    Parameters
    ----------
    idx : array_like of int
        Indices along ``axis``.  Must broadcast against the shape obtained
        by removing ``axis`` from ``shape`` (e.g. the output of
        ``np.argmax(a, axis=axis)``).
    shape : sequence of int
        Shape of the array that will be indexed.
    axis : int, optional
        Axis that ``idx`` refers to.  Negative values count from the last
        axis.  Defaults to the last axis.

    Returns
    -------
    tuple
        One index per dimension of ``shape``, usable directly as
        ``a[result]``.

    Raises
    ------
    ValueError
        If ``axis`` is outside ``[-len(shape), len(shape))``.
    """
    ndim = len(shape)
    # Reject out-of-range axes up front: otherwise a value such as -5 on a
    # 3-D shape would wrap to another negative number and silently produce
    # a wrong index tuple instead of an error.
    if not -ndim <= axis < ndim:
        raise ValueError(
            "axis %d is out of bounds for shape of dimension %d" % (axis, ndim)
        )
    if axis < 0:
        axis += ndim

    # Coerce to a tuple so that ``+`` below always means concatenation,
    # whatever sequence type the caller passed in.
    shape = tuple(shape)
    other_dims = shape[:axis] + shape[axis + 1:]

    # np.ix_ yields one open-mesh arange per remaining axis; together they
    # broadcast to the shape of the array with ``axis`` removed.
    index = list(np.ix_(*[np.arange(n) for n in other_dims]))
    index.insert(axis, idx)
    return tuple(index)
# Demo: take the maximum along ``axis`` by indexing with the argmax result.
a = np.random.randint(0, 10, (3, 4, 5))
axis = 1
idx = np.argmax(a, axis=axis)
# Both lines should print the same array: indexing ``a`` with the argmax
# positions must reproduce np.max along the same axis.
# print(...) with a single argument behaves the same on Python 2 and 3.
print(a[index_at(idx, a.shape, axis=axis)])
print(np.max(a, axis=axis))