How does python numpy.where() work?

礼貌的吻别 2020-12-04 14:54

I am playing with NumPy and digging through the documentation, and I have come across some magic, namely numpy.where():



        
3 answers
  •  渐次进展
    2020-12-04 15:34

    When called with only a condition, np.where returns a tuple whose length equals the number of dimensions of the array it is called on (in other words, its ndim). Each item of the tuple is a 1-D ndarray holding the indices, along that axis, of every element for which the condition is True. (Please don't confuse the number of dimensions with the shape: a 3×3 array has ndim 2 and shape (3, 3).)

    For example:

    import numpy as np

    x = np.arange(9).reshape(3, 3)
    print(x)
    # [[0 1 2]
    #  [3 4 5]
    #  [6 7 8]]
    y = np.where(x > 4)
    print(y)
    # (array([1, 2, 2, 2]), array([2, 0, 1, 2]))
    # (on some platforms each array's repr also shows dtype=int64)
    


    y is a tuple of length 2 because x.ndim is 2. The first item holds the row indices of all elements greater than 4, and the second item holds their column indices. As you can see, [1, 2, 2, 2] are the row indices of 5, 6, 7 and 8, and [2, 0, 1, 2] are their column indices. Note that the indices come out in row-major (C-style) order: the array is scanned row by row, left to right within each row.
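    As a sketch of how that tuple is typically used (same 3×3 x as above; the variable names are only illustrative): zip the per-axis arrays into (row, column) pairs, or pass the tuple straight back as an index to get the matching values. It also checks that the one-argument form gives the same arrays as np.nonzero, and that np.argwhere returns the same coordinates laid out one row per match.

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
rows, cols = np.where(x > 4)

# Pair the per-axis index arrays into (row, column) coordinates.
coords = list(zip(rows.tolist(), cols.tolist()))
print(coords)  # [(1, 2), (2, 0), (2, 1), (2, 2)]

# The tuple works directly as an index and pulls out the matching values.
print(x[np.where(x > 4)])  # [5 6 7 8]

# One-argument np.where gives the same arrays as np.nonzero.
same_as_nonzero = all(
    np.array_equal(a, b) for a, b in zip(np.where(x > 4), np.nonzero(x > 4))
)
print(same_as_nonzero)  # True

# np.argwhere returns the same coordinates, one row per matching element.
print(np.argwhere(x > 4).tolist())  # [[1, 2], [2, 0], [2, 1], [2, 2]]
```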

    Similarly,

    x = np.arange(27).reshape(3, 3, 3)
    np.where(x > 4)
    


    will return a tuple of length 3 because x has 3 dimensions.
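    A quick check of the length rule at both ends, as a sketch: the 3-D array yields three index arrays (one per axis), and a 1-D array still yields a tuple, just of length one, which is why code often indexes the result with [0]. The 1-D array a below is made up for illustration.

```python
import numpy as np

# 3-D case from the text: one index array per axis.
x3 = np.arange(27).reshape(3, 3, 3)
idx3 = np.where(x3 > 4)
print(len(idx3))     # 3
print(len(idx3[0]))  # 22 (the values 5..26 match)

# 1-D case: still a tuple, but with a single index array inside.
a = np.array([3, 8, 1, 9])
idx1 = np.where(a > 2)
print(type(idx1).__name__, len(idx1))  # tuple 1
print(idx1[0])                         # [0 1 3]
```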

    But wait, there's more to np.where!

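    When two more arguments are passed, as in np.where(condition, a, b), np.where no longer returns indices. Instead it builds a new array: the positions where condition is True (exactly the positions listed by the tuple above) take their value from a, and all other positions take it from b. a and b can be scalars or arrays that broadcast against condition, and the input array itself is not modified.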

    import numpy as np

    x = np.arange(9).reshape(3, 3)
    y = np.where(x > 4, 1, 0)
    print(y)
    # [[0 0 0]
    #  [0 0 1]
    #  [1 1 1]]
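    The replacement values don't have to be scalars. In the sketch below (the -1 fill value and the [10, 20, 30] row are arbitrary choices for illustration), the second and third arguments are arrays or scalars that broadcast against the condition, and each call returns a new array while x stays untouched. The last lines rebuild the first result by hand with a boolean mask, to show that the three-argument form is an elementwise pick between two sources.

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

# Keep x where the condition holds, use -1 everywhere else.
kept = np.where(x > 4, x, -1)
print(kept)
# [[-1 -1 -1]
#  [-1 -1  5]
#  [ 6  7  8]]

# A length-3 row broadcasts across all rows of the condition.
from_row = np.where(x > 4, np.array([10, 20, 30]), 0)
print(from_row)
# [[ 0  0  0]
#  [ 0  0 30]
#  [10 20 30]]

# np.where returned new arrays; x itself is unchanged.
print(x[0])  # [0 1 2]

# The same result as `kept`, built by hand with a boolean mask.
mask = x > 4
manual = np.full_like(x, -1)
manual[mask] = x[mask]
print(np.array_equal(manual, kept))  # True
```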
    
