Find the row indexes of several values in a numpy array

醉话见心 2020-11-22 02:01

I have an array X:

X = np.array([[4,  2],
              [9,  3],
              [8,  5],
              [3,  3],
              [5,  6]])

And I want to find the row indexes of several values (whole rows) in this array.

6 answers
  •  生来不讨喜
    2020-11-22 02:36

    Another alternative is to use asvoid (below) to view each row as a single value of void dtype. This reduces the 2D array to a 1D array of row values, so you can use np.in1d, or its replacement np.isin (np.in1d is deprecated as of NumPy 2.0), as usual:

    import numpy as np
    
    def asvoid(arr):
        """
        Based on http://stackoverflow.com/a/16973510/190597 (Jaime, 2013-06)
        View the array as dtype np.void (bytes). The items along the last axis are
        viewed as one value. This allows comparisons to be performed which treat
        entire rows as one value.
        """
        arr = np.ascontiguousarray(arr)
        if np.issubdtype(arr.dtype, np.floating):
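            # Care needs to be taken here since
            # np.array([-0.]).view(np.void) != np.array([0.]).view(np.void).
            # Adding 0. converts -0. to 0. Use `arr + 0.` rather than `arr += 0.`:
            # np.ascontiguousarray may return the caller's array itself, and an
            # in-place add would silently modify it.
            arr = arr + 0.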
        return arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))
    
    X = np.array([[4,  2],
                  [9,  3],
                  [8,  5],
                  [3,  3],
                  [5,  6]])
    
    searched_values = np.array([[4, 2],
                                [3, 3],
                                [5, 6]])
    
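    # The void views compare raw bytes, so both arrays need the same dtype
    # and the same number of columns for rows to be matched.
    # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
    idx = np.flatnonzero(np.isin(asvoid(X), asvoid(searched_values)))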
    print(idx)
    # [0 3 4]
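
    The float branch in asvoid exists because the void view compares raw bytes rather than numeric values. A minimal sketch of the underlying issue, using only ndarray.tobytes to inspect the bit patterns (the variable names are illustrative): -0.0 and 0.0 are equal as numbers but differ in their sign bit, and adding 0.0 normalizes -0.0 to +0.0.

```python
import numpy as np

neg_zero = np.array([-0.0])
pos_zero = np.array([0.0])

# Numerically, -0.0 and 0.0 compare equal.
print(neg_zero == pos_zero)  # [ True]

# Their IEEE 754 bit patterns differ (sign bit), so any byte-level
# comparison, such as comparing void views, treats them as different.
print(neg_zero.tobytes() == pos_zero.tobytes())  # False

# Adding 0.0 turns -0.0 into +0.0, after which the bytes agree.
print((neg_zero + 0.0).tobytes() == pos_zero.tobytes())  # True
```

    The same reasoning means NaNs only match when their bit patterns are identical, since the void view never performs a numeric comparison.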
    
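    To cross-check the void-view result on small inputs, plain broadcasting gives the same row indexes. This is a different technique from asvoid: it builds an n × m × k boolean array (rows of X × searched rows × columns), so memory grows with the product of the two row counts. A sketch reusing the example arrays:

```python
import numpy as np

X = np.array([[4, 2],
              [9, 3],
              [8, 5],
              [3, 3],
              [5, 6]])
searched_values = np.array([[4, 2],
                            [3, 3],
                            [5, 6]])

# X[:, None, :] has shape (5, 1, 2) and searched_values[None, :, :] has
# shape (1, 3, 2); comparing them broadcasts to (5, 3, 2).
# matches[i, j] is True when row i of X equals searched row j in every column.
matches = (X[:, None, :] == searched_values[None, :, :]).all(axis=-1)

# Indexes of the rows of X that equal at least one searched row, in X's order.
idx = np.flatnonzero(matches.any(axis=1))
print(idx)  # [0 3 4]
```

    Unlike the void view, this compares values numerically, so -0.0 and 0.0 match without any special handling, and X and searched_values may have different (but comparable) dtypes.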
