Speed up Python2 nested loops with XOR

死守一世寂寞 2020-12-22 12:44

The answer to the question this was marked as a duplicate of is wrong and does not meet my needs.

My code aims to calculate a hash from a seri

3 Answers
  • 2020-12-22 13:33

    There is a bug in the accepted answer to "Python fast XOR over range algorithm": l must be decremented before the XOR calculation, not after it. Here's a repaired version, along with an assert test to verify that it gives the same result as the naive algorithm.

    def f(a):
        return (a, 1, a + 1, 0)[a % 4]
    
    def getXor(a, b):
        return f(b) ^ f(a-1)
    
    def gen_nums(start, length):
        l = length
        ans = 0
        while l > 0:
            l = l - 1
            ans ^= getXor(start, start + l)
            start += length
        return ans
    
    def answer(start, length):
        c = val = 0
        for i in xrange(length):
            for j in xrange(length - i):
                n = start + c + j
                #print '%d,' % n,
                val ^= n
            #print
            c += length
        return val
    
    for start in xrange(50):
        for length in xrange(100):
            a = answer(start, length)
            b = gen_nums(start, length)
            assert a == b, (start, length, a, b)
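
The lookup in f relies on the XOR of 0..a repeating with period 4 in a, and getXor then cancels the shared prefix 0..a-1. As a minimal sketch (the names xor_upto, xor_range and brute are illustrative, not from the answer), the closed form can be checked against a plain reduction. The code uses range so it should run on Python 2 and 3 alike:

```python
from functools import reduce
from operator import xor

def xor_upto(a):
    # XOR of 0..a inclusive; same table as f above.
    # For a == -1, Python's modulo gives index 3, i.e. 0 (empty range).
    return (a, 1, a + 1, 0)[a % 4]

def xor_range(a, b):
    # XOR of a..b inclusive: the prefix 0..a-1 cancels itself out.
    return xor_upto(b) ^ xor_upto(a - 1)

def brute(a, b):
    # Reference implementation: XOR every value one by one.
    return reduce(xor, range(a, b + 1), 0)

for a in range(0, 40):
    for b in range(a, 80):
        assert xor_range(a, b) == brute(a, b)
```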
    

    Over those ranges of start and length, gen_nums is about 5 times faster than answer, but we can make it roughly twice as fast again (i.e., roughly 10 times as fast as answer) by eliminating those function calls:

    def gen_nums(start, length):
        ans = 0
        for l in xrange(length - 1, -1, -1):
            b = start + l
            # Second table is XOR(0..start-1), indexed by start % 4 (not (start-1) % 4)
            ans ^= (b, 1, b + 1, 0)[b % 4] ^ (0, start - 1, 1, start)[start % 4]
            start += length
        return ans
    

    As Mirek Opoka mentions in the comments, % 4 is equivalent to & 3, and it's faster because bitwise arithmetic is faster than performing integer division and throwing away the quotient. So we can replace the core step with

    ans ^= (b, 1, b + 1, 0)[b & 3] ^ (0, start - 1, 1, start)[start & 3]
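
One way to gain confidence in the inlined, bitmask form is to compare it with the naive double loop on small inputs. The sketch below is only an illustration: gen_nums_masked and answer_naive are made-up names, and range is used so the code should run on Python 2 or 3. The second tuple is indexed by start & 3, so its entries are the XOR of 0..start-1 for each residue of start.

```python
def answer_naive(start, length):
    # Row i starts at start + i * length and contributes length - i values.
    val = 0
    for i in range(length):
        for j in range(length - i):
            val ^= start + i * length + j
    return val

def gen_nums_masked(start, length):
    ans = 0
    for l in range(length - 1, -1, -1):
        b = start + l  # last value XORed in this row
        # XOR(0..b) ^ XOR(0..start-1), both looked up by residue mod 4
        ans ^= (b, 1, b + 1, 0)[b & 3] ^ (0, start - 1, 1, start)[start & 3]
        start += length
    return ans

# % 4 and & 3 agree for Python ints, including negative ones.
assert all((n % 4) == (n & 3) for n in range(-64, 64))

for start in range(50):
    for length in range(60):
        assert gen_nums_masked(start, length) == answer_naive(start, length)
```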
    
  • 2020-12-22 13:35

    I am afraid that, with the input you have in answer(2000000000,10**4) you'll never finish "in time".

    You can get a pretty significant speed up by improving the inner loop, not updating the c variable every time and using xrange instead of range, like this:

    def answer(start, length):
        val=0
        c=0
        for i in range(length):
            for j in range(length):
                if j < length-i:
                    val^=start+c
                c+=1
        return val
    
    
    def answer_fast(start, length):
        val = 0
        c = 0
        for i in xrange(length):
            for j in xrange(length - i):
                if j < length - i:  # always true given the xrange bound above
                    val ^= start + c + j
            c += length
        return val
    
    
    # print answer(10, 20000)
    print answer_fast(10, 20000)
    

    The profiler shows that answer_fast is about twice as fast:

    > python -m cProfile script.py
    366359392
            20004 function calls in 46.696 seconds
    
    Ordered by: standard name
    
    ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.000    0.000   46.696   46.696 script.py:1(<module>)
            1   44.357   44.357   46.696   46.696 script.py:1(answer)
            1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        20001    2.339    0.000    2.339    0.000 {range}
    
    > python -m cProfile script.py
    366359392
            3 function calls in 26.274 seconds
    
    Ordered by: standard name
    
    ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.000    0.000   26.274   26.274 script.py:1(<module>)
            1   26.274   26.274   26.274   26.274 script.py:12(answer_fast)
            1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
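
The gap between the two profiles is in line with a simple iteration count: the original inner loop body runs length**2 times, while the trimmed loop runs length * (length + 1) / 2 times. A small sketch of that arithmetic (the function names are illustrative, not part of the answer):

```python
def iterations_full(length):
    # Original: the inner loop always runs `length` times per row.
    return length * length

def iterations_trimmed(length):
    # Trimmed: row i runs length - i times, so length + (length-1) + ... + 1.
    return length * (length + 1) // 2

n = 20000
full = iterations_full(n)        # 400000000
trimmed = iterations_trimmed(n)  # 200010000
ratio = float(full) / trimmed    # just under 2
```

The measured ratio (46.7 s against 26.3 s, about 1.8x) is in the same ballpark as this estimate.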
    

    But if you want major speed-ups (orders of magnitude), you should consider rewriting your function in Cython.

    Here is the "cythonized" version of it:

    def answer(int start, int length):
        cdef int val = 0, c = 0, i, j
        for i in xrange(length):
            for j in xrange(length - i):
                if j < length - i:  # always true given the xrange bound above
                    val ^= start + c + j
            c += length
        return val
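
One caveat worth checking before relying on the cdef int declarations: a C int is commonly 32 bits, so every intermediate value has to stay at or below 2**31 - 1. As a rough sketch in plain Python (assuming a 32-bit signed int, which is typical but platform-dependent; INT32_MAX and fits_in_int32 are illustrative names), a conservative bound on start + c + j and on c itself is start + length * length:

```python
INT32_MAX = 2**31 - 1  # assumption: C int is a 32-bit signed integer

def fits_in_int32(start, length):
    # c never exceeds length * length, and j is below length,
    # so start + length * length bounds every intermediate value.
    return start + length * length <= INT32_MAX

ok_for_question = fits_in_int32(2000000000, 10**4)     # 2100000000 fits
ok_for_larger = fits_in_int32(2000000000, 2 * 10**4)   # 2400000000 does not
```

If the check fails for your inputs, declaring the variables as long long instead of int is one option.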
    

    With the same input parameters as above, it takes less than 200 ms instead of 20+ seconds, which is a speedup of more than 100x.

    > ipython
    
    In [1]: import pyximport; pyximport.install()
    Out[1]: (None, <pyximport.pyximport.PyxImporter at 0x7f3fed983150>)
    
    In [2]: import script2
    
    In [3]: timeit script2.answer(10, 20000)
    10 loops, best of 3: 188 ms per loop
    

    With your input parameters, it takes 58ms:

    In [5]: timeit script2.answer(2000000000,10**4)
    10 loops, best of 3: 58.2 ms per loop
    
  • 2020-12-22 13:40

    It looks like you can replace the inner loop and if with:

    for j in range(length - i):
        val ^= start + c
        c += 1
    c += i

    This should save some time when i gets bigger.

    I'm afraid I can't test this right now, sorry!
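
Since the suggestion above was posted untested, here is a small self-contained sketch that wraps it in a function and compares it with the original double loop on small inputs. The names answer_original and answer_skip are illustrative, and range is used so the code should run on Python 2 or 3:

```python
def answer_original(start, length):
    # The question's loop: visit every id, XOR only the first length - i per row.
    val = 0
    c = 0
    for i in range(length):
        for j in range(length):
            if j < length - i:
                val ^= start + c
            c += 1
    return val

def answer_skip(start, length):
    # Suggested variant: loop only over the ids that are XORed,
    # then jump c past the i ids the original visited but ignored.
    val = 0
    c = 0
    for i in range(length):
        for j in range(length - i):
            val ^= start + c
            c += 1
        c += i
    return val

for start in range(20):
    for length in range(30):
        assert answer_skip(start, length) == answer_original(start, length)
```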
