Match rows of two 2D arrays and get a row indices map using numpy

不思量自难忘° 2020-12-20 05:18

Suppose you have two 2D arrays A and B, and you want to find which rows of A are contained in B, and at which row index of B each one occurs. How do you do this most efficiently using numpy?

E.g.

2 Answers
  • 2020-12-20 06:03

    Approach #1

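    Here's one based on views: each row is viewed as a single np.void element, so whole rows can be compared at once. np.argwhere then returns the index pairs where the membership condition holds. Note that it compares every row of a with every row of b, so the work and memory grow with len(a) * len(b):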

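    import numpy as np

    def view1D(a, b):  # a, b are 2D arrays with the same dtype and column count
        a = np.ascontiguousarray(a)
        b = np.ascontiguousarray(b)
        # One void element spans the bytes of a whole row
        void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
        return a.view(void_dt).ravel(), b.view(void_dt).ravel()

    def argwhere_nd(a, b):
        A, B = view1D(a, b)
        # Broadcast to compare every row of a with every row of b
        return np.argwhere(A[:, None] == B)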
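
To make the view trick concrete, here is a minimal sketch on hypothetical sample data (the arrays below are made up for illustration and are not the inputs from the question). It shows that each row collapses to one void element and that comparing those elements compares whole rows. The comparison is byte-wise, so both arrays need the same dtype, and for floats values such as 0.0 and -0.0 would not match.

```python
import numpy as np

# Hypothetical sample data (not from the original post)
a = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])
b = np.array([[4, 5, 6],
              [0, 0, 0],
              [1, 2, 3]])

# View each contiguous row as one opaque void element
void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
A = np.ascontiguousarray(a).view(void_dt).ravel()
B = np.ascontiguousarray(b).view(void_dt).ravel()

print(A.shape, B.shape)  # one element per row: (3,) (3,)

# Pairwise comparison of whole rows via broadcasting
pairs = np.argwhere(A[:, None] == B)
print(pairs)  # each row is (index in a, index in b)
```

Here a[0] equals b[2] and a[1] equals b[0], so the pairs are [0, 2] and [1, 0].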
    

    Approach #2

    Here's another that sorts B and looks rows up with np.searchsorted. It runs in O(n log n) instead of comparing every pair, and hence is much better on performance, especially on large arrays:

    def argwhere_nd_searchsorted(a, b):
        A, B = view1D(a, b)
        sidxB = B.argsort()
        mask = np.isin(A, B)  # rows of a that occur in b
        cm = A[mask]
        idx0 = np.flatnonzero(mask)
        # Position of each match in sorted B, mapped back to original indices
        idx1 = sidxB[np.searchsorted(B, cm, sorter=sidxB)]
        return idx0, idx1  # idx0 : indices in a, idx1 : indices in b
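
The index bookkeeping in the searchsorted step may be easier to follow on plain 1D integers. The sketch below uses hypothetical values standing in for the row views: `sorter` lets np.searchsorted work on an unsorted array, and indexing the sorter with the result maps positions in sorted order back to positions in the original array. The np.isin mask in the function above matters because np.searchsorted returns an insertion position even for values that are absent.

```python
import numpy as np

# Hypothetical 1D stand-ins for the row views
B = np.array([40, 10, 30, 20])
queries = np.array([30, 10])  # values known to be present in B

sidx = B.argsort()  # order that sorts B
pos = np.searchsorted(B, queries, sorter=sidx)  # positions in sorted B
idx_in_B = sidx[pos]  # map back to positions in B itself

print(pos)       # [2 0]
print(idx_in_B)  # [2 1]
```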
    

    Approach #3

    Another O(n log n) one using argsort(). It assumes the rows are unique within a and within b:

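    def argwhere_nd_argsort(a, b):
        A, B = view1D(a, b)
        c = np.r_[A, B]
        # Stable sort keeps an element of A ahead of its equal in B
        idx = np.argsort(c, kind='mergesort')
        cs = c[idx]
        m0 = cs[:-1] == cs[1:]  # adjacent equal pairs are matches
        return idx[:-1][m0], idx[1:][m0] - len(A)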
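
As a rough illustration of why the stable sort matters, the same steps can be traced on hypothetical 1D integers (unique within each array). After concatenating and sorting stably, a value shared by both arrays appears as two adjacent entries, with the one from A first, so the left index of each equal pair belongs to A and the right one, after subtracting len(A), belongs to B.

```python
import numpy as np

# Hypothetical 1D stand-ins; values are unique within each array
A = np.array([5, 7, 9])
B = np.array([7, 1, 5, 3])

c = np.r_[A, B]                        # A first, then B
idx = np.argsort(c, kind='mergesort')  # stable sort
cs = c[idx]
m0 = cs[:-1] == cs[1:]                 # adjacent equal pairs

idx_a = idx[:-1][m0]                   # left element of each pair comes from A
idx_b = idx[1:][m0] - len(A)           # right element comes from B

print(idx_a, idx_b)  # [0 1] [2 0]
```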
    

    Sample runs with the same inputs as earlier:

    In [650]: argwhere_nd_searchsorted(a,b)
    Out[650]: (array([0, 1]), array([2, 0]))
    
    In [651]: argwhere_nd_argsort(a,b)
    Out[651]: (array([0, 1]), array([2, 0]))
    
  • 2020-12-20 06:19

    You can take advantage of the automatic broadcasting:

    np.argwhere(np.all(a.reshape(3,1,-1) == b,2))
    

    which results in

    array([[0, 2],
           [1, 0]])
    

    Note that for floats you might want to replace the == comparison with np.isclose().
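
A minimal sketch of that float-tolerant variant follows. The helper name `match_rows_close` and the sample data are hypothetical. It uses `a[:, None, :]` instead of a hard-coded reshape so it works for any number of rows, and passes np.isclose's default tolerances through explicitly.

```python
import numpy as np

def match_rows_close(a, b, rtol=1e-05, atol=1e-08):
    """Index pairs (i, j) where row a[i] approximately equals row b[j]."""
    a = np.asarray(a)
    b = np.asarray(b)
    # a[:, None, :] has shape (len(a), 1, k); broadcasting against b
    # yields a (len(a), len(b), k) array of element-wise comparisons
    close = np.isclose(a[:, None, :], b, rtol=rtol, atol=atol)
    return np.argwhere(close.all(axis=2))

# Hypothetical float data: 0.1 + 0.2 is not exactly 0.3
a = np.array([[0.1 + 0.2, 1.0],
              [2.0, 3.0]])
b = np.array([[2.0, 3.0],
              [0.3, 1.0]])

print(match_rows_close(a, b))
```

With exact == only the second row of a would match. The intermediate array holds len(a) * len(b) * k booleans, so this form suits small to medium inputs.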
