Question
I am trying to speed up code that uses NumPy's where() function. There are two calls to where(), each returning the indices at which its condition is True. The two index arrays are then compared for overlap with NumPy's intersect1d(), and the length of the intersection is returned.
import numpy as np

def find_match(x, y, z):
    A = np.where(x == z)  # indices where x equals z
    B = np.where(y == z)  # indices where y equals z
    #A = True
    #B = True
    return len(np.intersect1d(A, B))  # number of indices common to both

N = np.power(10, 8)
M = 10
X = np.random.randint(M, size=N)
Y = np.random.randint(M, size=N)
Z = np.random.randint(M, size=N)
#print(X,Y,Z)
print(find_match(X, Y, Z))
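The two np.where() calls each produce an index array, and np.intersect1d() keeps the indices present in both. The sketch below, on small hand-made arrays chosen only for illustration, checks that this count equals the number of positions where both boolean masks are True, which is the observation the answers build on. The [0] picks the index array out of the tuple that np.where() returns.

```python
import numpy as np

# Small illustrative inputs (made-up values, not from the question).
x = np.array([1, 2, 3, 1, 5])
y = np.array([1, 0, 3, 1, 4])
z = np.array([1, 2, 3, 0, 4])

# Original approach: indices where each condition holds, then their overlap.
A = np.where(x == z)[0]  # positions where x == z
B = np.where(y == z)[0]  # positions where y == z
via_indices = len(np.intersect1d(A, B))

# Mask approach: combine the boolean masks first, then count True entries.
via_mask = np.count_nonzero((x == z) & (y == z))

print(A, B, via_indices, via_mask)  # [0 1 2] [0 2 4] 2 2
```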
Timing:
This code takes about 8 seconds on my laptop. If I replace both np.where() calls with A = True and B = True, it takes about 5 seconds; if I replace only one of them, it takes about 6 seconds.
Scaling up to N = np.power(10, 9), the code takes 87 seconds. Replacing both np.where() statements brings it down to 51 seconds, and replacing just one takes about 61 seconds.
My question: how can I merge the two np.where() statements to speed up the code?
What I've tried: this version of the code is already about 4x faster than an earlier one that did a slower lookup with for-loops. Multiprocessing will be used at a higher level in this code, so I can't also apply it here.
For the record, the actual data are strings, so doing integer math on them won't help.
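Elementwise == works the same way on NumPy string arrays, so mask-based approaches do not need integer arithmetic on the data itself; only the resulting booleans are combined and counted. A minimal sketch with made-up string values standing in for the real data:

```python
import numpy as np

# Hypothetical string data standing in for the real inputs.
x = np.array(["cat", "dog", "cow", "cat"])
y = np.array(["cat", "dog", "pig", "cat"])
z = np.array(["cat", "owl", "cow", "cat"])

# Elementwise == on string arrays yields boolean masks, just as with integers.
both = (x == z) & (y == z)
print(both)                    # [ True False False  True]
print(np.count_nonzero(both))  # 2
```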
Version info:
Python 3.9.1 (default, Jan 8 2021, 17:17:43)
[Clang 12.0.0 (clang-1200.0.32.28)] on darwin
>>> import numpy
>>> print(numpy.__version__)
1.19.5
Answer 1:
Does this help?
def find_match2(x, y, z):
    return len(np.nonzero(np.logical_and(x == z, y == z))[0])
Sample run:
In [227]: print(find_match(X,Y,Z))
1000896
In [228]: print(find_match2(X,Y,Z))
1000896
In [229]: %timeit find_match(X,Y,Z)
2.37 s ± 70.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [230]: %timeit find_match2(X,Y,Z)
272 ms ± 9.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
I've added np.random.seed(210) before creating the arrays for the sake of reproducibility.
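A variant of the same idea, sketched here as an assumption rather than something timed in this answer: np.count_nonzero() counts True entries directly, so the index array built by np.nonzero() is never materialized, and &= reuses the first mask's memory instead of allocating a third boolean array. The name find_match_count is hypothetical, and N is kept small so the example runs quickly.

```python
import numpy as np

def find_match_count(x, y, z):
    """Count positions where x == z and y == z (hypothetical helper name)."""
    mask = x == z       # first boolean temporary
    mask &= (y == z)    # combine in place instead of allocating another array
    return int(np.count_nonzero(mask))  # count without building an index array

np.random.seed(210)
N, M = 10**6, 10
X = np.random.randint(M, size=N)
Y = np.random.randint(M, size=N)
Z = np.random.randint(M, size=N)

# Should agree with len(np.nonzero(np.logical_and(X == Z, Y == Z))[0]).
print(find_match_count(X, Y, Z))
```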
Answer 2:
Two versions that scale differently depending on size:
def find_match1(x, y, z):
    return (x == y).astype(int) @ (y == z).astype(int)  # equality and summation in one step

def find_match2(x, y, z):
    out = np.zeros_like(x)
    np.equal(x, y, out=out, where=np.equal(y, z))  # only calculates x==y if y==z
    return out.sum()
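Both versions count positions where x == y and y == z; by transitivity that is the same set of positions as x == z and y == z in the question. The sketch below walks through the two tricks on small arrays chosen for illustration: the dot product of two 0/1 vectors counts positions where both are 1, and the where= argument makes the ufunc write only where the mask is True, leaving the zeros from np.zeros_like() elsewhere.

```python
import numpy as np

x = np.array([1, 2, 3, 1, 5])
y = np.array([1, 0, 3, 1, 4])
z = np.array([1, 2, 3, 0, 4])

# Dot product of two 0/1 vectors = number of positions where both are 1.
xy = (x == y).astype(int)  # [1, 0, 1, 1, 0]
yz = (y == z).astype(int)  # [1, 0, 1, 0, 1]
print(xy @ yz)             # 2

# With where=, np.equal only writes out[i] where the mask is True;
# elsewhere out keeps its initial value (0 from zeros_like).
out = np.zeros_like(x)
np.equal(x, y, out=out, where=(y == z))
print(out)                 # [1 0 1 0 0]
print(out.sum())           # 2
```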
Testing different data sizes:
N = np.power(10, 7)
...
%timeit find_match(X,Y,Z)
206 ms ± 12.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit find_match1(X,Y,Z)
70.7 ms ± 1.67 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit find_match2(X,Y,Z)
74.7 ms ± 3.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
N = np.power(10, 8)
...
%timeit find_match(X,Y,Z)
2.51 s ± 168 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit find_match1(X,Y,Z)
886 ms ± 154 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit find_match2(X,Y,Z)
776 ms ± 26.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
EDIT: since @Tonechas's answer is faster than both, here's a numba method:
from numba import njit

@njit
def find_match_jit(x, y, z):
    out = 0
    for i, j, k in zip(x, y, z):
        if i == j and j == k:
            out += 1
    return out
find_match_jit(X,Y,Z) #run it once to compile
Out[]: 1001426
%timeit find_match_jit(X,Y,Z) # N = 10**8
204 ms ± 13.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
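Because numba compiles ordinary Python, the loop logic can be checked for correctness without numba on small inputs. The sketch below is the same loop body undecorated (the name find_match_loop is hypothetical), compared against a mask-based count; interpreted Python loops like this are far too slow to use at N = 10**8.

```python
import numpy as np

def find_match_loop(x, y, z):
    """Same loop body as find_match_jit, without @njit (hypothetical name)."""
    out = 0
    for i, j, k in zip(x, y, z):
        if i == j and j == k:  # x == y and y == z together imply x == z
            out += 1
    return out

# Small arrays only: this runs as plain interpreted Python.
rng = np.random.default_rng(0)
x = rng.integers(10, size=1000)
y = rng.integers(10, size=1000)
z = rng.integers(10, size=1000)

assert find_match_loop(x, y, z) == np.count_nonzero((x == z) & (y == z))
```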
If threading is allowed:
@njit(parallel=True)
def find_match_jit_p(x, y, z):
    xy = x == y
    yz = y == z
    return np.logical_and(xy, yz).sum()
find_match_jit_p(X,Y,Z)
Out[]: 1001426
%timeit find_match_jit_p(X,Y,Z)
84.6 ms ± 2.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
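If neither numba nor threads are an option, one plain-NumPy alternative (an assumption, not part of the original answer and not timed here) is to process the arrays in blocks, which keeps each boolean temporary small instead of allocating several full-length masks. The function name and the default chunk size are illustrative choices.

```python
import numpy as np

def find_match_chunked(x, y, z, chunk=1_000_000):
    """Count positions where x == z and y == z, one block at a time (hypothetical helper)."""
    total = 0
    for start in range(0, len(x), chunk):
        stop = start + chunk
        xs, ys, zs = x[start:stop], y[start:stop], z[start:stop]  # slices are views, not copies
        total += int(np.count_nonzero((xs == zs) & (ys == zs)))
    return total

rng = np.random.default_rng(1)
X = rng.integers(10, size=2_500_000)
Y = rng.integers(10, size=2_500_000)
Z = rng.integers(10, size=2_500_000)
print(find_match_chunked(X, Y, Z) == np.count_nonzero((X == Z) & (Y == Z)))  # True
```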
Source: https://stackoverflow.com/questions/65823425/combine-numpy-where-statements