Broadcasting a Python function onto NumPy arrays

鱼传尺愫 2020-12-15 11:08

Let's say we have a particularly simple function like

import scipy as sp
def func(x, y):
   return x + y

This function evidently works for

3 answers
  •  [愿得一人]
    2020-12-15 11:23

    np.vectorize is a general way to convert Python functions that operate on numbers into numpy functions that operate on ndarrays.

    However, as you point out, it isn't very fast, since it is using a Python loop "under the hood".
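As a minimal sketch of what that looks like in practice, the snippet below wraps a scalar-only function with np.vectorize. The name `scalar_func` and its branching body are assumptions chosen to match the np.where example in this answer, not code taken from the question.

```python
import numpy as np

def scalar_func(x, y):
    # Hypothetical scalar-only function: the `if` would fail on whole arrays.
    if x > y:
        return x + y
    return x - y

# np.vectorize wraps the scalar function so it accepts (and broadcasts) arrays.
vfunc = np.vectorize(scalar_func)

x = np.array([-2, -1, 0, 1, 2])
y = np.array([-2, -1, 0, 1, 2])

# A (5, 1) column combined with a (1, 5) row broadcasts to a (5, 5) result.
result = vfunc(x[:, np.newaxis], y[np.newaxis, :])
print(result.shape)
```

The wrapper is convenient, but it still calls `scalar_func` once per element from Python.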

    To achieve better speed, you have to hand-craft a function that expects numpy arrays as input and takes advantage of that numpy-ness:

    import numpy as np
    
    def func2(x, y):
        return np.where(x > y, x + y, x - y)
    
    x = np.array([-2, -1, 0, 1, 2])
    y = np.array([-2, -1, 0, 1, 2])
    
    xx = x[:, np.newaxis]
    yy = y[np.newaxis, :]
    
    print(func2(xx, yy))
    # [[ 0 -1 -2 -3 -4]
    #  [-3  0 -1 -2 -3]
    #  [-2 -1  0 -1 -2]
    #  [-1  0  1  0 -1]
    #  [ 0  1  2  3  0]]
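
The 5x5 result comes from broadcasting a column against a row. The sketch below makes the shapes explicit with np.broadcast_arrays and checks one entry by hand; the indices `i` and `j` are arbitrary picks for illustration.

```python
import numpy as np

x = np.array([-2, -1, 0, 1, 2])
y = np.array([-2, -1, 0, 1, 2])

xx = x[:, np.newaxis]   # shape (5, 1): a column
yy = y[np.newaxis, :]   # shape (1, 5): a row

# np.broadcast_arrays returns views with the common broadcast shape.
bx, by = np.broadcast_arrays(xx, yy)
print(xx.shape, yy.shape, bx.shape)

# Entry [i, j] of the result pairs x[i] with y[j].
grid = np.where(xx > yy, xx + yy, xx - yy)
i, j = 3, 2
expected = x[i] + y[j] if x[i] > y[j] else x[i] - y[j]
print(grid[i, j] == expected)
```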
    

    Regarding performance:

    test.py:

    import numpy as np
    
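    def func2a(x, y):
        # Elementwise select: x+y where x > y, otherwise x-y.
        return np.where(x > y, x + y, x - y)
    
    def func2b(x, y):
        # Boolean-mask assignment into a preallocated result array.
        ind = x > y
        z = np.empty(ind.shape, dtype=x.dtype)
        z[ind] = (x + y)[ind]
        z[~ind] = (x - y)[~ind]
        return z
    
    def func2c(x, y):
        # x, y = x[:, None], y[None, :]
        # Compute x+y everywhere, then overwrite the entries where x <= y.
        A, L = x + y, x <= y
        A[L] = (x - y)[L]
        return A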
    
    N=40  # array length; the timings below were run with N=30 and N=1000
    x = np.random.random(N)
    y = np.random.random(N)
    
    xx = x[:, np.newaxis]
    yy = y[np.newaxis, :]
    

    Running:

    With N=30:

    % python -mtimeit -s'import test' 'test.func2a(test.xx,test.yy)'
    1000 loops, best of 3: 219 usec per loop
    
    % python -mtimeit -s'import test' 'test.func2b(test.xx,test.yy)'
    1000 loops, best of 3: 488 usec per loop
    
    % python -mtimeit -s'import test' 'test.func2c(test.xx,test.yy)'
    1000 loops, best of 3: 248 usec per loop
    

    With N=1000:

    % python -mtimeit -s'import test' 'test.func2a(test.xx,test.yy)'
    10 loops, best of 3: 93.7 msec per loop
    
    % python -mtimeit -s'import test' 'test.func2b(test.xx,test.yy)'
    10 loops, best of 3: 367 msec per loop
    
    % python -mtimeit -s'import test' 'test.func2c(test.xx,test.yy)'
    10 loops, best of 3: 186 msec per loop
    

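    These timings suggest that func2a is the fastest: marginally ahead of func2c at N=30 (219 vs. 248 usec) and roughly twice as fast at N=1000 (93.7 vs. 186 msec). func2b is the slowest in both cases.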
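
For anyone who wants to repeat the comparison without the command line, here is a rough sketch using the timeit module from inside Python. The size `N = 200`, the seeded random generator, and the repeat counts are assumptions picked to keep the run short; absolute numbers will differ by machine and NumPy version.

```python
import timeit
import numpy as np

def func2a(x, y):
    return np.where(x > y, x + y, x - y)

def func2b(x, y):
    ind = x > y
    z = np.empty(ind.shape, dtype=x.dtype)
    z[ind] = (x + y)[ind]
    z[~ind] = (x - y)[~ind]
    return z

def func2c(x, y):
    A, L = x + y, x <= y
    A[L] = (x - y)[L]
    return A

N = 200  # assumed size, chosen only to keep the run short
rng = np.random.default_rng(0)
xx = rng.random(N)[:, np.newaxis]
yy = rng.random(N)[np.newaxis, :]

reference = func2a(xx, yy)
timings = {}
for f in (func2a, func2b, func2c):
    # All three must agree before their timings are worth comparing.
    assert np.array_equal(f(xx, yy), reference)
    best = min(timeit.repeat(lambda: f(xx, yy), number=20, repeat=3))
    timings[f.__name__] = best / 20  # seconds per call

for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e6:.1f} usec per call")
```

Checking the outputs against each other first guards against benchmarking a variant that is fast only because it is wrong.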
