Numpy advanced selection not working

Submitted by 旧时模样 on 2019-11-28 12:54:14

You want

b[np.ix_([0, 1], [0, 1, 2])]

You also need to do the same thing for b[[0, 1], [0, 1]], because that's not actually doing what you think it is:

b[np.ix_([0, 1], [0, 1])]
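To make the difference concrete, here is a minimal sketch comparing the two forms. The question's own b is not shown here, so this assumes a stand-in 4-D array with the same shape as the one used in the second answer below.

```python
import numpy as np

# Assumption: stand-in for the question's b (its real shape is not shown here)
b = np.random.rand(5, 4, 3, 2)

paired = b[[0, 1], [0, 1]]        # only the pairs (0, 0) and (1, 1)
grid = b[np.ix_([0, 1], [0, 1])]  # every combination of i in [0, 1] and j in [0, 1]
print(paired.shape)  # (2, 3, 2)
print(grid.shape)    # (2, 2, 3, 2)

# Cross-check the np.ix_ result against an explicit double loop over i and j
expected = np.array([[b[i, j] for j in [0, 1]] for i in [0, 1]])
assert np.array_equal(grid, expected)
```
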

The problem here is that advanced indexing does something completely different from what you think it does. You've made the mistake of thinking that b[[0, 1], [0, 1, 2]] means "take every entry b[i, j] of b where i is 0 or 1 and j is 0, 1, or 2". This is a reasonable mistake to make, considering that it seems to work that way when you have only one list in the indexing expression, like

b[:, [1, 3, 5], 2]
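A small sketch of why the one-list case looks like "pick those rows": the list [1, 3, 5] is only paired with the scalar 2, so nothing gets zipped against it. The array a here is hypothetical; axis 1 just needs at least 6 entries for the list to be valid.

```python
import numpy as np

# Hypothetical array: axis 1 has 6 entries so that [1, 3, 5] is in range
a = np.arange(2 * 6 * 3).reshape(2, 6, 3)

picked = a[:, [1, 3, 5], 2]
print(picked.shape)  # (2, 3)

# Each output element is a[i, [1, 3, 5][n], 2]: the scalar 2 is simply
# reused for every n, so the result looks like "rows 1, 3 and 5, column 2"
for i in range(2):
    for n, j in enumerate([1, 3, 5]):
        assert picked[i, n] == a[i, j, 2]
```
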

In fact, for an array A and one-dimensional integer arrays I and J, A[I, J] is an array where

A[I, J][n] == A[I[n], J[n]]
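A quick check of that rule on a small 2-D array (the names A, I and J are just for illustration):

```python
import numpy as np

A = np.arange(12).reshape(3, 4)
I = np.array([0, 2, 1])
J = np.array([3, 0, 1])

# Element n of the result is A[I[n], J[n]]: the index arrays are zipped, not crossed
result = A[I, J]
print(result)  # [3 8 5]

assert all(result[n] == A[I[n], J[n]] for n in range(len(I)))
```
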

This generalizes in the natural way to more index arrays, so for example

A[I, J, K][n] == A[I[n], J[n], K[n]]

and to higher-dimensional index arrays, so if I and J are two-dimensional, then

A[I, J][m, n] == A[I[m, n], J[m, n]]
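Both generalizations can be checked the same way. This sketch uses small made-up arrays: three 1-D index arrays into a 3-D array, then two 2-D index arrays into a 2-D array.

```python
import numpy as np

# Three 1-D index arrays: element n is A3[I[n], J[n], K[n]]
A3 = np.arange(24).reshape(2, 3, 4)
I, J, K = [0, 1], [2, 0], [3, 1]
print(A3[I, J, K])  # [11 13], i.e. A3[0, 2, 3] and A3[1, 0, 1]

# Two 2-D index arrays: element [m, n] is A[I2[m, n], J2[m, n]]
A = np.arange(12).reshape(3, 4)
I2 = np.array([[0, 1], [2, 0]])
J2 = np.array([[1, 1], [3, 0]])
out = A[I2, J2]
print(out)
# [[ 1  5]
#  [11  0]]

# The result takes the shape of the index arrays
assert out.shape == I2.shape
```
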

It also applies the broadcasting rules to the index arrays, and converts lists in the indexes to arrays. This is much more powerful than what you expected to happen, but it means that to do what you were trying to do, you need something like

b[[[0],
   [1]], [[0, 1, 2]]]

np.ix_ is a helper that will do that for you so you don't have to write a dozen brackets.
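A minimal sketch of what np.ix_ actually builds, and how broadcasting turns it into a grid of (i, j) pairs. The array b here is a random stand-in with the same shape as in the second answer below.

```python
import numpy as np

b = np.random.rand(5, 4, 3, 2)  # stand-in array, same shape as the example below

# np.ix_ turns the two lists into index arrays of shapes (2, 1) and (1, 3)
rows, cols = np.ix_([0, 1], [0, 1, 2])
print(rows.shape, cols.shape)  # (2, 1) (1, 3)

# Broadcasting those shapes gives a (2, 3) grid of (i, j) pairs
r, c = np.broadcast_arrays(rows, cols)
print(r.tolist())  # [[0, 0, 0], [1, 1, 1]]
print(c.tolist())  # [[0, 1, 2], [0, 1, 2]]

# The hand-written nested-list index and the np.ix_ version select the same block
manual = b[[[0], [1]], [[0, 1, 2]]]
helper = b[np.ix_([0, 1], [0, 1, 2])]
assert manual.shape == helper.shape == (2, 3, 3, 2)
assert np.array_equal(manual, helper)
```
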

I think you misunderstood the advanced selection syntax for this case. I used your example, just made it smaller so it is easier to see.

import numpy as np
b = np.random.rand(5, 4, 3, 2)

# advanced selection works as expected
print(b[[0, 1], [0, 1]])  # http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
                          # this picks (i, j) = (0, 0), a 3x2 matrix, and (i, j) = (1, 1), another 3x2 matrix

# doesn't work - why?
# print(b[[0, 1], [0, 1, 2]])  # raises IndexError: index arrays of shapes (2,) and (3,) cannot be broadcast together

print(b[[0, 1, 2], [0, 1, 2]])  # works, but picks only the pairs (0, 0), (1, 1) and (2, 2)

Output:

[[[ 0.27334558  0.90065184]
  [ 0.8624593   0.34324983]
  [ 0.19574819  0.2825373 ]]

 [[ 0.38660087  0.63941692]
  [ 0.81522421  0.16661912]
  [ 0.81518479  0.78655536]]]
[[[ 0.27334558  0.90065184]
  [ 0.8624593   0.34324983]
  [ 0.19574819  0.2825373 ]]

 [[ 0.38660087  0.63941692]
  [ 0.81522421  0.16661912]
  [ 0.81518479  0.78655536]]

 [[ 0.65336551  0.1435357 ]
  [ 0.91380873  0.45225145]
  [ 0.57255923  0.7645396 ]]]
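To see both behaviours from the example side by side, here is a small sketch with the same random b: the mismatched lists raise an IndexError, and the equal-length lists only select the "diagonal" pairs rather than a block.

```python
import numpy as np

b = np.random.rand(5, 4, 3, 2)

failed = False
try:
    b[[0, 1], [0, 1, 2]]
except IndexError as exc:
    # The index arrays have shapes (2,) and (3,), which cannot be broadcast together
    failed = True
    print("IndexError:", exc)

# Equal-length lists are zipped into pairs, so this is b[0, 0], b[1, 1], b[2, 2]
diag = b[[0, 1, 2], [0, 1, 2]]
assert np.array_equal(diag, np.stack([b[0, 0], b[1, 1], b[2, 2]]))
```
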