Numpy `logical_or` for more than two arguments

庸人自扰 2020-11-22 09:33

Numpy's logical_or function takes no more than two arrays to compare. How can I find the union of more than two arrays? (The same question could be asked with regard to Numpy's logical_and and obtaining the intersection of more than two arrays.)

7 answers
  •  鱼传尺愫
    2020-11-22 10:08

    I tried three different methods to compute the logical_and of a list l of k arrays, each of size n:

    1. Using a recursive numpy.logical_and (see below)
    2. Using numpy.logical_and.reduce(l)
    3. Using numpy.vstack(l).all(axis=0)

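As a quick sanity check that these approaches agree, here is a minimal sketch on three hand-picked arrays (k = 3, n = 4). Method 1 is represented here by `functools.reduce(numpy.logical_and, l)`, a plain chain of pairwise calls; that stand-in is an assumption of this sketch, not the recursive function from the benchmark code.

```python
import functools

import numpy

# k = 3 hand-picked boolean arrays of length n = 4.
l = [
    numpy.array([True, False, False, True]),
    numpy.array([True, True, False, False]),
    numpy.array([True, True, False, True]),
]

# Stand-in for method 1: chained pairwise calls, (l[0] op l[1]) op l[2].
and_pairwise = functools.reduce(numpy.logical_and, l)
or_pairwise = functools.reduce(numpy.logical_or, l)

# Method 2: ufunc.reduce treats the list as a (k, n) array and reduces axis 0.
and_red = numpy.logical_and.reduce(l)
or_red = numpy.logical_or.reduce(l)

# Method 3: stack into a (k, n) array, then reduce each column.
and_stk = numpy.vstack(l).all(axis=0)
or_stk = numpy.vstack(l).any(axis=0)

print(and_red.tolist())  # [True, False, False, False]
print(or_red.tolist())   # [True, True, False, True]

# All three methods give identical results for both operations.
same = all(
    numpy.array_equal(x, y)
    for x, y in [(and_pairwise, and_red), (and_red, and_stk),
                 (or_pairwise, or_red), (or_red, or_stk)]
)
print(same)  # True
```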
    I then did the same for logical_or. Surprisingly, the recursive method turned out to be the fastest.

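    import numpy
    import perfplot
    
    def and_recursive(*l):
        # Base cases: a single array (cast to bool) or one pairwise logical_and.
        if len(l) == 1:
            return l[0].astype(bool)
        elif len(l) == 2:
            return numpy.logical_and(l[0], l[1])
        elif len(l) > 2:
            # Combine the first two arrays, recurse on the rest, then AND the two partial results.
            return and_recursive(and_recursive(*l[:2]), and_recursive(*l[2:]))
    
    def or_recursive(*l):
        # Same recursion as and_recursive, using logical_or.
        if len(l) == 1:
            return l[0].astype(bool)
        elif len(l) == 2:
            return numpy.logical_or(l[0], l[1])
        elif len(l) > 2:
            return or_recursive(or_recursive(*l[:2]), or_recursive(*l[2:]))
    
    def and_reduce(*l):
        # ufunc.reduce applies logical_and cumulatively along axis 0 of the stacked input.
        return numpy.logical_and.reduce(l)
    
    def or_reduce(*l):
        return numpy.logical_or.reduce(l)
    
    def and_stack(*l):
        # Stack into a (k, n) array, then require True in every row of each column.
        return numpy.vstack(l).all(axis=0)
    
    def or_stack(*l):
        # Stack into a (k, n) array, then require True in at least one row of each column.
        return numpy.vstack(l).any(axis=0)
    
    k = 10  # number of arrays to be combined
    
    perfplot.plot(
        # k random boolean arrays of length n
        setup=lambda n: [numpy.random.choice(a=[False, True], size=n) for _ in range(k)],
        kernels=[
            lambda l: and_recursive(*l),
            lambda l: and_reduce(*l),
            lambda l: and_stack(*l),
            lambda l: or_recursive(*l),
            lambda l: or_reduce(*l),
            lambda l: or_stack(*l),
        ],
        labels=['and_recursive', 'and_reduce', 'and_stack', 'or_recursive', 'or_reduce', 'or_stack'],
        n_range=[2 ** j for j in range(20)],
        logx=True,
        logy=True,
        xlabel="len(a)",
        # The and_* and or_* kernels return different results, so skip the equality check.
        equality_check=None,
    )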
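Two further variants could be added as kernels to the perfplot call above; this is a minimal sketch, and no timings for them are claimed here. `or_fold` is an iterative left fold with `functools.reduce`, and `or_inplace` / `and_inplace` accumulate into a single buffer through the ufunc `out=` argument, so the k - 1 intermediate results are not allocated as new arrays. All three function names are hypothetical and were not part of the original benchmark.

```python
import functools

import numpy


def or_fold(*l):
    # Left fold ((l[0] | l[1]) | l[2]) | ...: k - 1 pairwise calls,
    # each returning a newly allocated bool array.
    return functools.reduce(numpy.logical_or, l)


def or_inplace(*l):
    # numpy.array copies by default, so the caller's first array is untouched.
    acc = numpy.array(l[0], dtype=bool)
    for a in l[1:]:
        # Write the result back into acc instead of allocating a new array.
        numpy.logical_or(acc, a, out=acc)
    return acc


def and_inplace(*l):
    acc = numpy.array(l[0], dtype=bool)
    for a in l[1:]:
        numpy.logical_and(acc, a, out=acc)
    return acc


# Usage: check against ufunc.reduce on random inputs.
k, n = 10, 1000
arrays = [numpy.random.choice(a=[False, True], size=n) for _ in range(k)]
assert numpy.array_equal(or_fold(*arrays), numpy.logical_or.reduce(arrays))
assert numpy.array_equal(or_inplace(*arrays), numpy.logical_or.reduce(arrays))
assert numpy.array_equal(and_inplace(*arrays), numpy.logical_and.reduce(arrays))
```

To benchmark them, they would be passed like the others, e.g. `lambda l: or_inplace(*l)`. One difference from `or_recursive`: with a single input, `or_fold` returns that array unchanged rather than casting it to bool.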
    

    [Figure: performance vs. len(a), k = 4]
    
    [Figure: performance vs. len(a), k = 10]

    There appears to be an approximately constant time overhead that persists even for larger n.
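One way to look at that overhead without perfplot is to time each function at n = 1, where the element-wise work is negligible, so the measured time is mostly fixed per-call cost, and compare it with a large n. The sketch below uses only `timeit` and NumPy; the helper `per_call_us` and the loop-based `or_pairwise` (a stand-in for `or_recursive`) are hypothetical names, and the printed numbers will depend on the machine and NumPy version.

```python
import timeit

import numpy


def or_pairwise(*l):
    # Plain loop of pairwise logical_or calls (stand-in for or_recursive).
    out = l[0]
    for a in l[1:]:
        out = numpy.logical_or(out, a)
    return out


def or_reduce(*l):
    return numpy.logical_or.reduce(l)


def or_stack(*l):
    return numpy.vstack(l).any(axis=0)


def per_call_us(func, arrays, number=200):
    # Best of 5 repeats, converted to microseconds per call.
    best = min(timeit.repeat(lambda: func(*arrays), number=number, repeat=5))
    return best / number * 1e6


k = 10  # number of arrays, as in the benchmark above
rng = numpy.random.default_rng(0)
for n in (1, 100_000):
    arrays = [rng.random(n) < 0.5 for _ in range(k)]
    timings = {f.__name__: per_call_us(f, arrays) for f in (or_pairwise, or_reduce, or_stack)}
    print(n, {name: round(t, 1) for name, t in timings.items()})
```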
