Question
For performance reasons, I have started to use Numba alongside NumPy. My Numba algorithm is working, but I have the feeling that it should be faster. There is one point that is slowing it down. Here is the code snippet:
@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if (numpy.all(ws[x1][0:i] == ws[x2][0:i]) and
                            numpy.all(ws[x1][i:l] == ws[x3][i:l])):
                        y += 1
In my opinion the if statement is slowing it down. Is there a better way? (What I am trying to achieve here is related to a previously posted problem: Count possibilities for single crossovers.) ws is a NumPy array of shape (gn, l) containing 0's and 1's.
Answer 1:
Given the logic of wanting to ensure all items are equal, you can take advantage of the fact that if any are not equal, you can short-circuit (i.e., stop comparing) the calculation. I modified your original function slightly so that (1) you don't repeat the same comparison twice, and (2) y is summed over all the nested loops, so there is a return value that can be compared:
@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if np.all(ws[x1][0:i] == ws[x2][0:i]) and np.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1
                        ysum += 1
    return ysum
@nb.njit
def rfunc2(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    incr_y = True
                    for j in range(i):
                        if ws[x1, j] != ws[x2, j]:
                            incr_y = False
                            break
                    if incr_y:
                        for j in range(i, l):
                            if ws[x1, j] != ws[x3, j]:
                                incr_y = False
                                break
                    if incr_y:
                        y += 1
                        ysum += 1
    return ysum
I don't know what the complete function looks like, but hopefully this helps you get started on the right path.
Now for some timings:
l = 7
a = 2
gn = a**l
ws = np.random.randint(0,2,size=(gn,l))
In [23]: %timeit rfunc1(ws, a, l)
1 loop, best of 3: 2.11 s per loop

%timeit rfunc2(ws, a, l)
1 loop, best of 3: 39.9 ms per loop

In [27]: rfunc1(ws, a, l)
Out[27]: 131919

In [30]: rfunc2(ws, a, l)
Out[30]: 131919
That gives you roughly a 50x speed-up.
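A further refinement is possible (a hypothetical sketch, not part of the answer): the prefix test ws[x1][0:i] == ws[x2][0:i] does not depend on x3 at all, so it can be hoisted out of the innermost loop and the suffix matches counted separately. Shown here in plain Python for clarity; under Numba you would decorate it with @nb.njit exactly as in the answer.

```python
# Hypothetical rfunc3: hoist the x3-independent prefix comparison out of
# the innermost loop, so a failed prefix skips the whole x3 scan.
def rfunc3(ws, a, l):
    gn = a ** l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for i in range(1, l):
                # prefix comparison, independent of x3
                prefix_ok = True
                for j in range(i):
                    if ws[x1][j] != ws[x2][j]:
                        prefix_ok = False
                        break
                if not prefix_ok:
                    continue
                # only now scan x3 for rows whose suffix matches ws[x1]
                for x3 in range(gn):
                    suffix_ok = True
                    for j in range(i, l):
                        if ws[x1][j] != ws[x3][j]:
                            suffix_ok = False
                            break
                    if suffix_ok:
                        ysum += 1
    return ysum
```

The total count is unchanged because it is still a sum of the same indicator over all (x1, x2, x3, i); only the loop order differs.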
Answer 2:
Instead of just "having a feeling" about where your bottleneck is, why not profile your code and find out exactly where it is?
The first aim of profiling is to test a representative system to identify what’s slow (or using too much RAM, or causing too much disk I/O or network I/O).
Profiling typically adds overhead (10x to 100x slowdowns are typical), yet you still want your code to run as similarly to its real-world use as possible. Extract a test case and isolate the piece of the system that you need to test. Preferably, it will already have been written in its own set of modules.
Basic techniques include the %timeit magic in IPython, time.time(), and a timing decorator (see example below). You can use these techniques to understand the behavior of statements and functions.
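Outside IPython, the standard-library timeit module gives the same kind of per-statement timing (a minimal sketch; the timed expression here is arbitrary):

```python
import timeit

# Run an arbitrary statement 1000 times and report the total wall time.
total = timeit.timeit("sum(x * x for x in range(1000))", number=1000)
print(f"1000 runs took {total:.4f} s ({total / 1000:.6f} s per run)")
```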
Then you have cProfile, which will give you a high-level view of the problem so you can direct your attention to the critical functions.
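For example, cProfile can be driven programmatically and its statistics sorted by cumulative time (a minimal sketch; waste_time is just a stand-in for your own hot function):

```python
import cProfile
import io
import pstats

def waste_time():
    # Deliberately slow stand-in for a real workload.
    total = 0
    for i in range(200):
        for j in range(200):
            total += i * j
    return total

pr = cProfile.Profile()
pr.enable()
waste_time()
pr.disable()

# Render the top 5 functions by cumulative time into a string.
buf = io.StringIO()
pstats.Stats(pr, stream=buf).sort_stats("cumulative").print_stats(5)
report = buf.getvalue()
print(report)
```

The report lists each function's call count and cumulative time, which is usually enough to pick the function worth inspecting line by line.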
Next, look at line_profiler, which will profile your chosen functions on a line-by-line basis. The result will include a count of the number of times each line is called and the percentage of time spent on each line. This is exactly the information you need to understand what's running slowly and why.
perf stat helps you understand the number of instructions that are ultimately executed on a CPU and how efficiently the CPU's caches are utilized. This allows for advanced-level tuning of matrix operations.
heapy can track all of the objects inside Python's memory. This is great for hunting down strange memory leaks. If you're working with long-running systems, then dowser will interest you: it allows you to introspect live objects in a long-running process via a web browser interface.
To help you understand why your RAM usage is high, check out memory_profiler. It is particularly useful for tracking RAM usage over time on a labeled chart, so you can explain to colleagues (or yourself) why certain functions use more RAM than expected.
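memory_profiler is a third-party package; as a quick standard-library alternative for the same question ("where did the RAM go?"), Python's built-in tracemalloc can snapshot allocation totals (a minimal sketch; the allocation is arbitrary):

```python
import tracemalloc

tracemalloc.start()
blobs = [bytes(10_000) for _ in range(100)]  # allocate roughly 1 MB
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

print(f"current={current} bytes, peak={peak} bytes")
```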
Example: Defining a decorator to automate timing measurements
import time
from functools import wraps

def timefn(fn):
    @wraps(fn)
    def measure_time(*args, **kwargs):
        t1 = time.time()
        result = fn(*args, **kwargs)
        t2 = time.time()
        print("@timefn:" + fn.__name__ + " took " + str(t2 - t1) + " seconds")
        return result
    return measure_time

@timefn
def your_func(var1, var2):
    ...
For more information, I suggest reading High Performance Python (Micha Gorelick and Ian Ozsvald), from which the above was sourced.
Source: https://stackoverflow.com/questions/41051553/performance-nested-loop-in-numba