Suppose I have a matrix A with some arbitrary values:

A = array([[ 2, 4, 5, 3],
           [ 1, 6, 8, 9],
           [ 8, 7, 0, 2]])
More recent versions of NumPy (1.15 and later) have added a take_along_axis function that does the job:
In [203]: A = np.array([[ 2, 4, 5, 3],
     ...:               [ 1, 6, 8, 9],
     ...:               [ 8, 7, 0, 2]])

In [204]: B = np.array([[0, 0, 1, 2],
     ...:               [0, 3, 2, 1],
     ...:               [3, 2, 1, 0]])

In [205]: np.take_along_axis(A, B, 1)
Out[205]:
array([[2, 2, 4, 5],
       [1, 9, 8, 6],
       [2, 0, 7, 8]])
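On NumPy versions older than 1.15, which lack take_along_axis, a minimal sketch of the same row-wise gather uses plain advanced indexing. A column of row numbers with shape (3, 1) is broadcast against B (shape (3, 4)), so that element [i, j] of the result is A[i, B[i, j]]. The name rows is just an illustrative local variable.

```python
import numpy as np

A = np.array([[2, 4, 5, 3],
              [1, 6, 8, 9],
              [8, 7, 0, 2]])
B = np.array([[0, 0, 1, 2],
              [0, 3, 2, 1],
              [3, 2, 1, 0]])

# Row indices as a (3, 1) column; broadcasting pairs row i with every
# column index in B[i], giving result[i, j] = A[i, B[i, j]].
rows = np.arange(A.shape[0])[:, np.newaxis]
C = A[rows, B]
print(C)

# On NumPy >= 1.15 this matches take_along_axis along axis 1.
assert np.array_equal(C, np.take_along_axis(A, B, axis=1))
```

take_along_axis is essentially a packaged version of this broadcasting trick that works for any axis and any number of dimensions.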
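There's also a put_along_axis, which does the reverse: it writes values into an array, in place, at the positions given by an index array along an axis.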
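Here is a small sketch of put_along_axis, assuming a hypothetical index array idx that picks one column per row. Note that the function modifies its first argument in place and returns None, which is why it is applied to a copy of A here.

```python
import numpy as np

A = np.array([[2, 4, 5, 3],
              [1, 6, 8, 9],
              [8, 7, 0, 2]])

# Hypothetical indices: column 3 in row 0, column 0 in row 1, column 1 in row 2.
idx = np.array([[3], [0], [1]])

out = A.copy()                              # keep A unchanged
np.put_along_axis(out, idx, 99, axis=1)     # scalar 99 is broadcast to every target
print(out)
```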