Faster implementation for ReLU derivative in Python?

心不动则不痛 submitted on 2019-12-22 17:48:56

Question


I have implemented the ReLU derivative as:

import numpy as np

def relu_derivative(x):
    return (x>0)*np.ones(x.shape)

I also tried:

def relu_derivative(x):
    # note: this overwrites the input array x in place
    x[x>=0]=1
    x[x<0]=0
    return x

X has shape (3072, 10000), and both versions take a long time to compute. Is there a faster solution?
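For reference, the derivative of ReLU(x) = max(0, x) is 1 for x > 0 and 0 for x < 0. It is undefined at x = 0, so an implementation has to pick a value there. The two attempts above differ at exactly that point (x>0 gives 0 at zero, x>=0 gives 1), and the second one overwrites the array passed to it. A small NumPy check of both points, with the two functions renamed v1/v2 for illustration:

```python
import numpy as np

def relu_derivative_v1(x):
    # first attempt: float64 result, 0.0 where x == 0
    return (x > 0) * np.ones(x.shape)

def relu_derivative_v2(x):
    # second attempt: writes into x itself, 1.0 where x == 0
    x[x >= 0] = 1
    x[x < 0] = 0
    return x

x = np.array([-2.0, 0.0, 3.5])
print(relu_derivative_v1(x))   # [0. 0. 1.]
out = relu_derivative_v2(x)
print(out)                     # [0. 1. 1.]
print(out is x)                # True: the input array was overwritten
```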


Answer 1:


Approach #1: Using numexpr

When working with large data, we can use the numexpr module, which supports multi-core processing when the intended operations can be expressed as arithmetic ones. Here, one way would be:

(X>=0)+0

Thus, for our case, it would be:

import numexpr as ne

ne.evaluate('(X>=0)+0')
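Called as above, ne.evaluate looks up X in the calling frame. A minimal sketch that wraps the call in a function and passes the array explicitly through evaluate's local_dict argument. The function name relu_derivative_ne is hypothetical, and the plain-NumPy fallback for when numexpr is not installed is an assumption added for this sketch, not part of the answer:

```python
import numpy as np

try:
    import numexpr as ne
except ImportError:  # numexpr is optional in this sketch
    ne = None

def relu_derivative_ne(X):
    """Return 1 where X >= 0 and 0 elsewhere, as an integer array."""
    if ne is None:
        # fallback (assumption): same expression evaluated by NumPy, single-threaded
        return (X >= 0) + 0
    # numexpr compiles the string and evaluates it in parallel chunks;
    # adding 0 turns the boolean comparison into an integer result
    return ne.evaluate('(X>=0)+0', local_dict={'X': X})

X = np.random.randn(4, 5)
D = relu_derivative_ne(X)
print(D)
```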

Approach #2: Using NumPy views

Another trick is to use views: view the boolean mask produced by the comparison as an int8 array, like so:

(X>=0).view('i1')

Performance-wise, it should be essentially the same as computing X>=0 alone, since the view does not copy any data.
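The view works because NumPy stores each boolean as one byte holding 0 or 1, so reinterpreting the mask as int8 ('i1') needs neither a conversion pass nor a copy. A quick check with np.shares_memory, including the one caveat that the view aliases the mask:

```python
import numpy as np

X = np.random.randn(3, 4)
mask = X >= 0              # boolean array, 1 byte per element
as_int = mask.view('i1')   # the same bytes reinterpreted as int8

print(as_int.dtype)                    # int8
print(np.shares_memory(mask, as_int))  # True: no data was copied

# caveat: the view aliases the mask, so writing to one changes the other
as_int[0, 0] = 0
print(mask[0, 0])                      # False
```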

Timings

Comparing all posted solutions on a random array:

In [14]: np.random.seed(0)
    ...: X = np.random.randn(3072,10000)

In [15]: # OP's soln-1
    ...: def relu_derivative_v1(x):
    ...:      return (x>0)*np.ones(x.shape)
    ...: 
    ...: # OP's soln-2     
    ...: def relu_derivative_v2(x):
    ...:    x[x>=0]=1
    ...:    x[x<0]=0
    ...:    return x

In [16]: %timeit ne.evaluate('(X>=0)+0')
10 loops, best of 3: 27.8 ms per loop

In [17]: %timeit (X>=0).view('i1')
100 loops, best of 3: 19.3 ms per loop

In [18]: %timeit relu_derivative_v1(X)
1 loop, best of 3: 269 ms per loop

In [19]: %timeit relu_derivative_v2(X)
1 loop, best of 3: 89.5 ms per loop
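Outside IPython, a similar comparison can be run with the standard timeit module. A rough, NumPy-only sketch (numexpr is left out so it runs without it) on a smaller array than the one above; best_time and relu_derivative_view are hypothetical helper names. Note that relu_derivative_v2 overwrites its input, so in the session above every call after the first ran on an array that already held only 0s and 1s. Here that entry gets a fresh copy on each call instead, which adds the copying cost to its timing. Results depend on hardware, so no numbers are shown.

```python
import timeit

import numpy as np

def relu_derivative_v1(x):
    return (x > 0) * np.ones(x.shape)

def relu_derivative_v2(x):
    # overwrites x in place
    x[x >= 0] = 1
    x[x < 0] = 0
    return x

def relu_derivative_view(x):
    # Approach #2: boolean mask reinterpreted as int8, no copy
    return (x >= 0).view('i1')

def best_time(fn, x, number=5, repeat=3, copy_input=False):
    """Best per-call time in seconds over `repeat` runs of `number` calls."""
    def call():
        # v2 mutates its argument, so optionally hand it a fresh copy every call
        return fn(x.copy() if copy_input else x)
    return min(timeit.repeat(call, number=number, repeat=repeat)) / number

if __name__ == '__main__':
    np.random.seed(0)
    X = np.random.randn(1000, 1000)  # smaller than the (3072, 10000) case above
    for name, fn, copy_input in [
        ('v1', relu_derivative_v1, False),
        ('v2 (with copy)', relu_derivative_v2, True),
        ('view', relu_derivative_view, False),
    ]:
        ms = best_time(fn, X, copy_input=copy_input) * 1e3
        print(f'{name:>15}: {ms:.2f} ms')
```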

The numexpr-based timing (In [16]) was run with 8 threads, so with more threads available for compute it should improve further. Related post: how to control numexpr's multi-core functionality.
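numexpr's thread count can be changed at run time with numexpr.set_num_threads, which returns the previous setting, and numexpr.detect_number_of_cores reports the cores it sees. A small sketch of a context manager that sets a thread count for one block of work and then restores the old value; the helper name numexpr_threads is hypothetical, and it simply does nothing if numexpr is not installed:

```python
from contextlib import contextmanager

import numpy as np

try:
    import numexpr as ne
except ImportError:
    ne = None  # numexpr not installed; the helper below becomes a no-op

@contextmanager
def numexpr_threads(n):
    """Run the enclosed block with numexpr using n threads, then restore the old count."""
    if ne is None:
        yield
        return
    previous = ne.set_num_threads(n)  # returns the previous setting
    try:
        yield
    finally:
        ne.set_num_threads(previous)

X = np.random.randn(1000, 1000)
if ne is not None:
    print('cores detected:', ne.detect_number_of_cores())
    with numexpr_threads(1):
        single = ne.evaluate('(X>=0)+0', local_dict={'X': X})  # one thread
    multi = ne.evaluate('(X>=0)+0', local_dict={'X': X})  # previous thread count again
```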

Approach #3: Approach #1 + #2

Combine both approaches for the fastest option on large arrays:

In [27]: np.random.seed(0)
    ...: X = np.random.randn(3072,10000)

In [28]: %timeit ne.evaluate('X>=0').view('i1')
100 loops, best of 3: 14.7 ms per loop
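Wrapped as a function, Approach #3 lets numexpr compute the comparison (multi-threaded) and then reinterprets the resulting boolean array as int8 without a copy. A sketch with a hypothetical name, passing X via local_dict and falling back to a plain NumPy comparison when numexpr is unavailable (the fallback is an assumption, not part of the answer):

```python
import numpy as np

try:
    import numexpr as ne
except ImportError:
    ne = None

def relu_derivative_fast(X):
    """Return an int8 array: 1 where X >= 0, else 0."""
    if ne is not None:
        # multi-threaded comparison; numexpr returns a boolean array
        mask = ne.evaluate('X>=0', local_dict={'X': X})
    else:
        # single-threaded NumPy fallback
        mask = X >= 0
    # reinterpret the 1-byte booleans as int8: no extra pass, no copy
    return mask.view('i1')
```

Because the result is int8, multiplying it by a float64 upstream gradient upcasts the product to float64 as usual in NumPy.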


Source: https://stackoverflow.com/questions/54969120/faster-implementation-for-relu-derivative-in-python
