Performance nested loop in numba

Submitted by 孤街醉人 on 2019-12-20 02:55:06

Question


For performance reasons, I have started to use Numba besides NumPy. My Numba algorithm is working, but I have the feeling that it should be faster. There is one point which is slowing it down. Here is the code snippet:

import numba as nb
import numpy

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    # parentheses let the condition continue onto the next line
                    if (numpy.all(ws[x1][0:i] == ws[x2][0:i]) and
                            numpy.all(ws[x1][i:l] == ws[x3][i:l])):
                        y += 1
                    # the same check again, as originally posted
                    if (numpy.all(ws[x1][0:i] == ws[x2][0:i]) and
                            numpy.all(ws[x1][i:l] == ws[x3][i:l])):
                        y += 1

I suspect the if statement is what slows it down. Is there a better way? (What I am trying to achieve here is related to a problem I posted earlier: Count possibilites for single crossovers.) ws is a NumPy array of shape (gn, l) containing only 0s and 1s.
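To make the target of the loops concrete, here is a plain-Python sketch of what the counting loop computes: for every triple of rows (x1, x2, x3) and every crossover point i, it counts a match when x1 agrees with x2 on positions [0, i) and with x3 on positions [i, l). This sketch assumes each match should be counted once (as Answer 1's ysum does; the question's snippet repeats the check and so adds 2 per match). `crossover_count_reference` is a hypothetical name, and the function makes no attempt at speed; it is only meant as a correctness check on small inputs.

```python
def crossover_count_reference(ws, l):
    """Brute-force count over all triples (x1, x2, x3) and crossover points i.

    Adds 1 whenever row x1 equals row x2 on positions [0, i) and row x3 on
    positions [i, l), i.e. x1 is a single-point crossover of x2 and x3 at i.
    Works on a list of lists or on the rows of a 2-D NumPy array.
    """
    n = len(ws)
    total = 0
    for x1 in range(n):
        for x2 in range(n):
            for x3 in range(n):
                for i in range(1, l):
                    # `and` short-circuits: the suffix is compared only if the prefix matched
                    if (list(ws[x1][:i]) == list(ws[x2][:i])
                            and list(ws[x1][i:l]) == list(ws[x3][i:l])):
                        total += 1
    return total
```

For example, with all 8 distinct binary rows of length 3 as `ws`, every x1 has 4 prefix partners and 2 suffix partners at i = 1 (and 2 and 4 at i = 2), so the function returns 8 * (8 + 8) = 128.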


Answer 1:


Since you want to ensure that all items are equal, you can take advantage of the fact that as soon as one pair differs, the comparison can short-circuit (i.e., stop comparing). I modified your original function slightly so that (1) the same comparison is not done twice, and (2) y is summed over all the nested loops, so there is a return value that can be compared between versions:

import numba as nb
import numpy as np

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if np.all(ws[x1][0:i] == ws[x2][0:i]) and np.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1
                        ysum += 1

    return ysum


@nb.njit
def rfunc2(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):

                    # prefix check: stop at the first position where x1 and x2 differ
                    incr_y = True
                    for j in range(i):
                        if ws[x1, j] != ws[x2, j]:
                            incr_y = False
                            break

                    # suffix check: only needed if the prefix matched
                    if incr_y:
                        for j in range(i, l):
                            if ws[x1, j] != ws[x3, j]:
                                incr_y = False
                                break
                    if incr_y:
                        y += 1
                        ysum += 1
    return ysum

I don't know what the complete function looks like, but hopefully this helps you get started on the right path.
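To see why the explicit loops in rfunc2 win, compare the work done: `np.all(a == b)` first computes every element comparison into a temporary boolean array and only then reduces it, whereas a hand-written loop can return at the first mismatch. The hypothetical helper below makes that early exit visible by counting the element comparisons it performs (the counter is only for illustration and is not part of the answer's code). It is plain Python, so it runs on lists or on NumPy rows.

```python
def equal_short_circuit(ws, r1, r2, start, stop):
    """Compare ws[r1][start:stop] with ws[r2][start:stop] element by element.

    Returns (is_equal, comparisons), where comparisons is the number of
    element pairs inspected before the answer was known.
    """
    comparisons = 0
    for j in range(start, stop):
        comparisons += 1
        if ws[r1][j] != ws[r2][j]:
            return False, comparisons  # first mismatch: stop immediately
    return True, comparisons  # every pair matched, so all were inspected
```

With rows that differ in their first position, the helper answers after a single comparison, while `np.all(a == b)` would still compare all of them; only when the slices really are equal does the loop have to inspect every element.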

Now for some timings:

l = 7
a = 2
gn = a**l
ws = np.random.randint(0, 2, size=(gn, l))

In [23]: %timeit rfunc1(ws, a, l)
1 loop, best of 3: 2.11 s per loop

%timeit rfunc2(ws, a, l)
1 loop, best of 3: 39.9 ms per loop

In [27]: rfunc1(ws, a, l)
Out[27]: 131919

In [30]: rfunc2(ws, a, l)
Out[30]: 131919

That is roughly a 50x speed-up (2.11 s vs. 39.9 ms), with identical results.
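If only the total count is needed (the ysum that both versions return), the triple loop can be avoided altogether. This is a different technique from the answer's, sketched here under that assumption: for a fixed crossover point i, the number of (x2, x3) pairs that match x1 equals the number of rows sharing x1's prefix times the number of rows sharing x1's suffix. Tallying prefixes and suffixes once per i therefore gives the same sum in about gn * l**2 element operations instead of visiting gn**3 triples. `crossover_count_fast` is a hypothetical name; it uses only the standard library and accepts lists or NumPy rows.

```python
from collections import Counter


def crossover_count_fast(ws, l):
    """Same total as the brute-force triple loop, without enumerating triples.

    For each crossover point i, sum over rows x1 of
    (#rows with the same prefix row[:i]) * (#rows with the same suffix row[i:l]).
    """
    # Convert rows to hashable tuples of plain ints (also handles NumPy rows).
    rows = [tuple(int(v) for v in row[:l]) for row in ws]
    total = 0
    for i in range(1, l):
        prefix_counts = Counter(row[:i] for row in rows)
        suffix_counts = Counter(row[i:] for row in rows)
        for row in rows:
            total += prefix_counts[row[:i]] * suffix_counts[row[i:]]
    return total
```

Because it only builds dictionaries, this version does not need Numba at all. With all 8 distinct binary rows of length 3 it returns 128, and on any input it should agree with rfunc1 and rfunc2.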




Answer 2:


Instead of just "having a feeling" where your bottleneck is, why not profile your code and find exactly where?

The first aim of profiling is to test a representative system to identify what’s slow (or using too much RAM, or causing too much disk I/O or network I/O).

Profiling typically adds overhead (10x to 100x slowdowns are common), yet you still want the code to run as closely as possible to how it runs in a real-world situation. Extract a test case and isolate the piece of the system that you need to test. Ideally, it will already live in its own set of modules.

Basic techniques include the %timeit magic in IPython, time.time(), and a timing decorator (see example below). You can use these techniques to understand the behavior of statements and functions.
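One detail matters when timing Numba code with any of these techniques: the first call of an @njit function with new argument types triggers compilation, so that call should be excluded from the measurement. A minimal sketch, assuming a hypothetical helper `best_time` and using `time.perf_counter` (a high-resolution clock intended for measuring short durations) rather than `time.time()`:

```python
import time


def best_time(fn, *args, repeat=5, warmup=1):
    """Return the best wall-clock time (in seconds) of fn(*args) over `repeat` runs.

    The warm-up calls are not timed. For a Numba @njit function the first
    call with new argument types compiles the function, so timing it would
    mostly measure the compiler rather than the loops.
    """
    for _ in range(warmup):
        fn(*args)
    best = float("inf")
    for _ in range(repeat):
        t0 = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - t0)
    return best
```

For example, `best_time(rfunc2, ws, a, l)` would make one untimed warm-up call and then report the fastest of five timed runs.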

Then there is cProfile, which gives you a high-level view of the problem so you can direct your attention to the critical functions.
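A minimal cProfile sketch using only the standard library (`profile_call` is a hypothetical helper name). One caveat for this question: a function compiled with @nb.njit runs as machine code, so cProfile (and likewise line_profiler) records a call to rfunc2 as a single opaque entry. To see where time goes inside the loops, one option is to profile the pure-Python original, which Numba keeps as `rfunc2.py_func`, on a small input such as l = 4.

```python
import cProfile
import io
import pstats


def profile_call(fn, *args, top=10):
    """Run fn(*args) under cProfile and return (result, report_text).

    The report lists at most `top` functions, sorted by cumulative time.
    """
    profiler = cProfile.Profile()
    result = profiler.runcall(fn, *args)
    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf)
    stats.sort_stats("cumulative").print_stats(top)
    return result, buf.getvalue()
```

Usage would look like `result, report = profile_call(rfunc2.py_func, ws, a, l)` followed by `print(report)`.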

Next, look at line_profiler, which will profile your chosen functions on a line-by-line basis. The result will include a count of the number of times each line is called and the percentage of time spent on each line. This is exactly the information you need to understand what’s running slowly and why.

perf stat helps you understand the number of instructions that are ultimately executed on a CPU and how efficiently the CPU’s caches are utilized. This allows for advanced-level tuning of matrix operations.

heapy can track all of the objects inside Python’s memory. This is great for hunting down strange memory leaks. If you’re working with long-running systems, then dowser will interest you: it allows you to introspect live objects in a long-running process via a web browser interface.

To help you understand why your RAM usage is high, check out memory_profiler. It is particularly useful for tracking RAM usage over time on a labeled chart, so you can explain to colleagues (or yourself) why certain functions use more RAM than expected.

Example: Defining a decorator to automate timing measurements

import time
from functools import wraps

def timefn(fn):
    @wraps(fn)
    def measure_time(*args, **kwargs):
        t1 = time.time()
        result = fn(*args, **kwargs)
        t2 = time.time()
        # fn.__name__ works on Python 3 (fn.func_name existed only in Python 2)
        print("@timefn: " + fn.__name__ + " took " + str(t2 - t1) + " seconds")
        return result
    return measure_time

@timefn
def your_func(var1, var2):
    ...

For more information, I suggest reading High Performance Python by Micha Gorelick and Ian Ozsvald, from which the above was sourced.



Source: https://stackoverflow.com/questions/41051553/performance-nested-loop-in-numba
