How does Python's numpy.where() work?

礼貌的吻别 2020-12-04 14:54

I am playing with numpy and digging through the documentation, and I have come across some magic. Namely, I am talking about numpy.where():



        
  •  星月不相逢
    2020-12-04 15:25

    How do they achieve internally that you are able to pass something like x > 5 into a method?

    The short answer is that they don't.

    Any comparison on a numpy array returns a boolean array (i.e. __gt__, __lt__, etc. all return boolean arrays that are True wherever the given condition holds).

    E.g.

    import numpy as np

    x = np.arange(9).reshape(3, 3)
    print(x > 5)
    

    yields:

    [[False False False]
     [False False False]
     [ True  True  True]]
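To make "they don't" concrete: Python evaluates the comparison through the array's __gt__ method before np.where is ever called, so np.where only receives an ordinary, already-computed boolean array. A minimal sketch (the name mask is just an illustrative variable):

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

# `x > 5` is evaluated as x.__gt__(5); both produce the same new boolean array.
mask = x > 5
assert np.array_equal(mask, x.__gt__(5))
print(mask.dtype)  # bool

# np.where never sees the expression "x > 5", only the finished array,
# so passing a precomputed mask gives exactly the same indices.
rows_a, cols_a = np.where(x > 5)
rows_b, cols_b = np.where(mask)
print(np.array_equal(rows_a, rows_b) and np.array_equal(cols_a, cols_b))  # True
```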
    

    This is also why a statement such as if x > 5: raises a ValueError when x is a numpy array with more than one element. The condition is an array of True/False values, not a single value, so Python cannot decide which branch to take.
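A short sketch of that failure, and of reducing the boolean array to a single value explicitly with any() or all() (which of the two is right depends on what the condition is supposed to mean):

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

caught = None
try:
    if x > 5:  # Python calls bool() on a 3x3 boolean array here
        pass
except ValueError as err:
    caught = err  # "The truth value of an array with more than one element is ambiguous..."
print(type(caught).__name__, caught)

# Collapse the array to one boolean explicitly instead:
print((x > 5).any())  # True: at least one element is greater than 5
print((x > 5).all())  # False: not every element is greater than 5
```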

    Furthermore, numpy arrays can be indexed by boolean arrays. E.g., in this case x[x > 5] yields [6 7 8].
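A minimal sketch of boolean indexing, both for reading the selected elements and for assigning to them (y is a copy so that x stays unchanged):

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

# A boolean index of the same shape selects the elements where it is True,
# returned as a flattened 1-D array (a copy, not a view).
selected = x[x > 5]
print(selected)  # [6 7 8]

# The same kind of mask also works on the left-hand side of an assignment.
y = x.copy()
y[y > 5] = 0
print(y)
```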

    Honestly, it's fairly rare that you actually need numpy.where. Called with a single argument, it just returns the indices where a boolean array is True, as a tuple of index arrays (one per dimension). Usually you can do what you need with simple boolean indexing.
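A small sketch of what np.where returns for the example array. The one-argument form behaves like np.nonzero; the three-argument form np.where(condition, a, b), shown last, picks elements from a where the condition is True and from b elsewhere (the fill value 0 is an arbitrary choice for illustration):

```python
import numpy as np

x = np.arange(9).reshape(3, 3)

# One argument: a tuple of index arrays, one per dimension.
rows, cols = np.where(x > 5)
print(rows, cols)  # [2 2 2] [0 1 2]

# Same result as np.nonzero, and indexing with that tuple selects
# the same elements as boolean indexing with the mask.
print(np.array_equal(x[np.where(x > 5)], x[x > 5]))  # True

# Three arguments: keep x where the condition holds, use 0 elsewhere.
print(np.where(x > 5, x, 0))
```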
