Numba code slower than pure Python

Submitted by 时光总嘲笑我的痴心妄想 on 2019-12-02 18:23:17

The problem is that Numba can't infer the type of lookup. If you put a print(nb.typeof(lookup)) in your method, you'll see that Numba is treating it as an object, which is slow. Normally I would just declare the type of lookup in a locals dict, but I was getting a strange error. Instead I created a small wrapper so that I could explicitly declare its input and output types.

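import numba as nb
import numpy as np

# Explicit signature: takes a 1-D float64 array and returns a 1-D float64 array,
# so callers know the exact type of the result.
@nb.jit(nb.f8[:](nb.f8[:]))
def numba_cumsum(x):
    return np.cumsum(x)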

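# nb.autojit comes from older Numba releases; a bare @nb.jit (no signature)
# likewise compiles lazily from the argument types of the first call.
@nb.autojit
def numba_resample2(qs, xs, rands):
    n = qs.shape[0]
    # lookup = np.cumsum(qs)  # original version: Numba typed this as an object
    lookup = numba_cumsum(qs)  # typed wrapper, so lookup is a float64 array
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            # First bin whose cumulative weight exceeds the draw rands[j].
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results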

Then my timings are:

print("Timing Numba Function:")
%timeit numba_resample(qs, xs, rands)

print("Timing Revised Numba Function:")
%timeit numba_resample2(qs, xs, rands)

Timing Numba Function:
100 loops, best of 3: 8.1 ms per loop
Timing Revised Numba Function:
100000 loops, best of 3: 15.3 µs per loop
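%timeit is an IPython magic. Outside IPython, the standard-library timeit module can produce a similar best-of-N figure. The helper below is a sketch; best_time_per_call is a hypothetical name, and the warm-up call is a choice made here because the first call to a lazily compiled Numba function also triggers compilation.

```python
import timeit

import numpy as np


def best_time_per_call(func, *args, number=1000, repeat=3):
    """Return the best (minimum) average seconds per call over `repeat` runs."""
    func(*args)  # warm-up: triggers JIT compilation for lazily compiled functions
    totals = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    return min(totals) / number


# Example workload; swap in e.g. numba_resample2 with (qs, xs, rands).
rng = np.random.default_rng(0)
qs = rng.random(1000)
qs /= qs.sum()
seconds = best_time_per_call(np.cumsum, qs, number=200)
print(f"np.cumsum: {seconds * 1e6:.1f} µs per call")
```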

You can go even a little faster still if you use jit with an explicit signature instead of autojit, i.e. decorate numba_resample2 with:

@nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))

For me that lowers it from 15.3 microseconds to 12.5 microseconds, but it's still impressive how well autojit does.

Faster NumPy version (10x speedup compared to numpy_resample):

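import numpy as np

def numpy_faster(qs, xs, rands):
    lookup = np.cumsum(qs)
    # n x n boolean matrix via broadcasting: mm[j, i] is True where rands[j] < lookup[i]
    mm = lookup[None, :] > rands[:, None]
    # argmax returns the first True in each row (0 if a row is all False)
    I = np.argmax(mm, 1)
    return xs[I]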
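numpy_faster builds an n x n boolean matrix, so its memory use grows quadratically with n. A swapped-in alternative that is not from the answers above: assuming every weight in qs is non-negative, lookup is sorted, so np.searchsorted with side="right" returns, for each draw, the first index i with rands[j] < lookup[i], the same index the loop breaks on, using O(n log n) time and O(n) extra memory. numpy_searchsorted is a hypothetical name.

```python
import numpy as np


def numpy_searchsorted(qs, xs, rands):
    lookup = np.cumsum(qs)  # non-decreasing if every weight in qs is >= 0
    # side="right": index of the first lookup[i] strictly greater than each draw,
    # i.e. the first i with rands[j] < lookup[i].
    idx = np.searchsorted(lookup, rands, side="right")
    # A draw >= lookup[-1] gives idx == len(lookup), so xs[idx] raises IndexError
    # here instead of silently returning xs[0] as the argmax version does.
    return xs[idx]


def numpy_faster(qs, xs, rands):
    # Broadcasting version from the answer, repeated here for comparison.
    lookup = np.cumsum(qs)
    return xs[np.argmax(lookup[None, :] > rands[:, None], 1)]


rng = np.random.default_rng(42)
n = 500
qs = rng.random(n)
qs /= qs.sum()
xs = rng.random(n)
# Keep every draw strictly below the total weight so all indices are valid.
rands = rng.random(n) * np.cumsum(qs)[-1] * 0.99
same = np.array_equal(numpy_searchsorted(qs, xs, rands), numpy_faster(qs, xs, rands))
print("searchsorted matches broadcasting:", same)
```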