Test for membership in a 2D NumPy array

花落未央 2020-12-03 08:05

I have two 2D arrays of the same size:

import numpy as np
a = np.array([[1,2],[3,4],[5,6]])
b = np.array([[1,2],[3,4],[7,8]])

I want to know the rows of b that are in a.

4 Answers
  •  暖寄归人
    2020-12-03 08:47

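    What we'd really like to do is use np.in1d... except that np.in1d treats its inputs as 1-dimensional arrays of individual elements, and we need to compare whole rows of 2-dimensional arrays. However, we can view each row as a single item of dtype np.void (a fixed-size block of raw bytes), which gives us one comparable item per row: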

    arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))
    

    For example,

    In [15]: arr = np.array([[1, 2], [2, 3], [1, 3]])
    
    In [16]: arr = arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))
    
    In [30]: arr.dtype
    Out[30]: dtype('V16')
    
    In [31]: arr.shape
    Out[31]: (3, 1)
    
    In [37]: arr
    Out[37]: 
    array([[b'\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00'],
           [b'\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00'],
           [b'\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00']],
          dtype='|V16')
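
    The width of the void dtype is simply the item size times the number of columns, so the `V16` above reflects 8-byte integers and two columns; with a different integer size or column count the width changes. A minimal sketch of that arithmetic, where `row_void_dtype` is a hypothetical helper name that is not part of NumPy or of the answer:

```python
import numpy as np

def row_void_dtype(arr):
    # Hypothetical helper: one void item spanning a whole row (the last axis).
    return np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1]))

# dtypes are set explicitly so the result does not depend on the platform default.
a64 = np.array([[1, 2], [2, 3], [1, 3]], dtype=np.int64)
a32 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)

print(row_void_dtype(a64).itemsize)  # 8 bytes * 2 columns
print(row_void_dtype(a32).itemsize)  # 4 bytes * 3 columns

# Viewing keeps one item per row, with a trailing axis of length 1.
viewed = a64.view(row_void_dtype(a64))
print(viewed.shape)
```

    Both arrays being compared need the same dtype and the same number of columns, otherwise their rows have different byte widths and cannot match.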
    

    This makes each row of arr a single void item (a block of raw bytes). Now it is just a matter of hooking this up to np.in1d:

    import numpy as np
    
    def asvoid(arr):
        """
        Based on http://stackoverflow.com/a/16973510/190597 (Jaime, 2013-06)
        View the array as dtype np.void (bytes). The items along the last axis are
        viewed as one value. This allows comparisons to be performed on the entire row.
        """
        arr = np.ascontiguousarray(arr)
        if np.issubdtype(arr.dtype, np.floating):
            # Care needs to be taken here since
            # np.array([-0.]).view(np.void) != np.array([0.]).view(np.void)
            # Adding 0. converts -0. to 0. The addition is done out of place
            # so that the caller's array is not modified.
            arr = arr + 0.
        return arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))
    
    
    def inNd(a, b, assume_unique=False):
        a = asvoid(a)
        b = asvoid(b)
        return np.in1d(a, b, assume_unique)
    
    
    tests = [
        (np.array([[1, 2], [2, 3], [1, 3]]),
         np.array([[2, 2], [3, 3], [4, 4]]),
         np.array([False, False, False])),
        (np.array([[1, 2], [2, 2], [1, 3]]),
         np.array([[2, 2], [3, 3], [4, 4]]),
         np.array([True, False, False])),
        (np.array([[1, 2], [3, 4], [5, 6]]),
         np.array([[1, 2], [3, 4], [7, 8]]),
         np.array([True, True, False])),
        (np.array([[1, 2], [5, 6], [3, 4]]),
         np.array([[1, 2], [5, 6], [7, 8]]),
         np.array([True, True, False])),
        (np.array([[-0.5, 2.5, -2, 100, 2], [5, 6, 7, 8, 9], [3, 4, 5, 6, 7]]),
         np.array([[1.0, 2, 3, 4, 5], [5, 6, 7, 8, 9], [-0.5, 2.5, -2, 100, 2]]),
         np.array([False, True, True]))
    ]
    
    for a, b, answer in tests:
        result = inNd(b, a)
        try:
            assert np.all(answer == result)
        except AssertionError:
            print('''\
    a:
    {a}
    b:
    {b}
    
    answer: {answer}
    result: {result}'''.format(**locals()))
            raise
    else:
        print('Success!')
    

    yields

    Success!
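
    Applied to the arrays from the question, the approach can be sketched as below. This is a condensed restatement under a few assumptions, not the answer's exact code: `rows_in` and `rows_in_bruteforce` are hypothetical names, `np.isin` is swapped in for `np.in1d` (which recent NumPy releases mark as deprecated), and both inputs are assumed to share a dtype and a column count. The broadcasting version is a different, memory-hungry technique used here only as an independent cross-check.

```python
import numpy as np

def asvoid(arr):
    # View each row (last axis) as one np.void item of raw bytes.
    arr = np.ascontiguousarray(arr)
    if np.issubdtype(arr.dtype, np.floating):
        arr = arr + 0.  # normalise -0. to 0. without touching the caller's array
    return arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))

def rows_in(rows, pool):
    # Hypothetical name: True for each row of `rows` that also occurs in `pool`.
    return np.isin(asvoid(rows).ravel(), asvoid(pool).ravel())

def rows_in_bruteforce(rows, pool):
    # Cross-check via broadcasting; builds a len(rows) x len(pool) x ncols array.
    return (rows[:, None, :] == pool[None, :, :]).all(axis=-1).any(axis=-1)

a = np.array([[1, 2], [3, 4], [5, 6]])
b = np.array([[1, 2], [3, 4], [7, 8]])

result_void = rows_in(b, a)
result_bcast = rows_in_bruteforce(b, a)
print(result_void)
print(result_bcast)

# Float rows containing -0. still match rows containing 0.
fa = np.array([[0.0, 1.0]])
fb = np.array([[-0.0, 1.0], [2.0, 3.0]])
print(rows_in(fb, fa))
```

    The byte-view version compares bit patterns rather than values, which is why the -0. normalisation is needed; the broadcasting version compares values directly but its memory use grows with the product of the two row counts.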
    
