RuntimeWarning: invalid value encountered in greater

大城市里の小女人 submitted on 2020-04-08 18:53:27

Question


I tried to implement softmax with the following code (out_vec is a NumPy vector of floats):

import numpy as np

numerator = np.exp(out_vec)
denominator = np.sum(np.exp(out_vec))
out_vec = numerator/denominator

However, I got an overflow because of np.exp(out_vec). So I checked (manually) what the upper limit of np.exp() is, and found that np.exp(709) is still a finite number, but np.exp(710) overflows to np.inf. To avoid the overflow, I modified my code as follows:

out_vec[out_vec > 709] = 709  # prevent np.exp overflow
numerator = np.exp(out_vec)
denominator = np.sum(np.exp(out_vec))
out_vec = numerator/denominator

Now I get a different warning:

RuntimeWarning: invalid value encountered in greater
  out_vec[out_vec > 709] = 709

What's wrong with the line I added? When I searched for this specific warning, all I found was advice on how to suppress it. Simply ignoring it won't help me, because every time my code hits this warning, it doesn't produce the usual results.
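A minimal reproduction helps show the chain of events. This is an illustrative sketch with made-up inputs, not the asker's data: it computes the overflow limit of np.exp for float64, shows how inf / inf turns the naive softmax output into NaN, and shows that clipping each input at 709 still lets the denominator overflow once a few terms are summed.

```python
import numpy as np

# Largest finite float64 and the largest argument np.exp can take without overflowing
max_float = np.finfo(np.float64).max
exp_limit = np.log(max_float)  # about 709.78: np.exp(709) is finite, np.exp(710) is inf

# Naive softmax on large (made-up) inputs; warnings are silenced only for this demo
x = np.array([710.0, 711.0])
with np.errstate(over='ignore', invalid='ignore'):
    numerator = np.exp(x)                  # both entries overflow to inf
    naive = numerator / np.sum(numerator)  # inf / inf evaluates to nan

# Clipping afterwards cannot repair this: nan > 709 is False, so the NaN entries survive
with np.errstate(invalid='ignore'):
    naive[naive > 709] = 709

# Clipping the inputs at 709 is not enough either: each exp(709) is about 8.2e307,
# so the sum of three of them exceeds the float64 maximum (about 1.8e308)
with np.errstate(over='ignore'):
    clipped_numerator = np.exp(np.full(3, 709.0))
    clipped_denominator = np.sum(clipped_numerator)           # overflows to inf
    clipped_result = clipped_numerator / clipped_denominator  # finite / inf gives 0.0

print(exp_limit, numerator, naive, clipped_denominator, clipped_result)
```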


Answer 1:


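Your problem is caused by NaN elements in your out_vec array: comparing NaN with a number (here, out_vec > 709) is what raises the "invalid value" warning. Inf elements are a likely source of those NaNs, since inf/inf evaluates to NaN. You could use the following code to avoid this problem: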

if np.isnan(np.sum(out_vec)):
    out_vec = out_vec[~np.isnan(out_vec)]  # just remove NaN elements from the vector
out_vec[out_vec > 709] = 709
...

Or, if you want to keep the NaN values in your array:

# compare only non-NaN entries; NaN entries count as False and are left unchanged
out_vec[np.array([e > 709 if not np.isnan(e) else False for e in out_vec], dtype=bool)] = 709
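The list comprehension loops over the array in Python. A vectorized sketch of the same idea (a swapped-in technique, not the original answer's code) uses the `where=` and `out=` arguments that NumPy ufuncs such as np.greater accept: the comparison is evaluated only where the entry is not NaN, and the remaining mask entries keep their initial False. The helper name `clip_above_ignoring_nan` is hypothetical.

```python
import numpy as np


def clip_above_ignoring_nan(values, limit=709.0):
    """Clip entries above `limit` in place, leaving NaN entries untouched.

    Hypothetical helper: the comparison is only evaluated where the
    entry is not NaN, so NaN is never compared against `limit`.
    """
    not_nan = ~np.isnan(values)
    # Start from an all-False mask and fill in the comparison only where not_nan is True
    too_big = np.greater(values, limit,
                         out=np.zeros(values.shape, dtype=bool),
                         where=not_nan)
    values[too_big] = limit
    return values


out_vec = np.array([1.0, np.nan, 800.0, np.inf])
clip_above_ignoring_nan(out_vec)
print(out_vec)
```

As in the original one-liner, +inf compares greater than the limit, so it is clipped to 709 as well.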



Answer 2:


In my case (NaN values were being compared), the warning no longer showed up after calling this before the comparison:

import warnings
warnings.filterwarnings('ignore')  # silences all warnings, not just this one
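A global filter hides every warning for the rest of the program. A narrower alternative (a swapped-in technique, not part of the original answer) is NumPy's np.errstate context manager, which silences only floating-point "invalid value" warnings and only inside a with block. A minimal sketch with a made-up input:

```python
import numpy as np

out_vec = np.array([1.0, np.nan, 800.0])  # made-up input containing a NaN

# Suppress only "invalid value" floating-point warnings, and only inside this block
with np.errstate(invalid='ignore'):
    out_vec[out_vec > 709] = 709  # nan > 709 evaluates to False, so the NaN is left as-is

# Outside the block, NumPy's previous error settings apply again
print(out_vec)
```

Like the warning filter, this only hides the symptom: the NaN values are still in the array.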



Answer 3:


IMO, the better way is to use a numerically stable implementation of the log of a sum of exponentials (log-sum-exp):

import numpy as np
from scipy.special import logsumexp  # older SciPy versions had this in scipy.misc
out_vec = np.exp(out_vec - logsumexp(out_vec))
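The same stabilization can be sketched in plain NumPy, without SciPy, using the standard trick of subtracting the maximum before exponentiating. Softmax is unchanged when every input is shifted by the same constant, and after the shift the largest exponent is exp(0) = 1, so nothing overflows. The function name below is hypothetical; recent SciPy versions also provide a ready-made scipy.special.softmax.

```python
import numpy as np


def stable_softmax(x):
    """Softmax of a 1-D float array, computed without overflow.

    Subtracting max(x) leaves the result mathematically unchanged
    (the factor exp(-max) cancels in numerator and denominator),
    but keeps every exponent <= 0, so np.exp never overflows.
    """
    shifted = x - np.max(x)
    numerator = np.exp(shifted)
    return numerator / np.sum(numerator)


print(stable_softmax(np.array([1000.0, 1000.0])))  # [0.5 0.5]
print(stable_softmax(np.array([1.0, 2.0, 3.0])))
```

If x contains NaN, np.max(x) is NaN and every output becomes NaN, so NaN inputs still need handling as discussed in the other answers.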



Answer 4:


If this happens because of NaN values, this might help:

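result = out_vec.copy()  # work on a copy so out_vec keeps its original values
result[~np.isnan(out_vec)] = out_vec[~np.isnan(out_vec)] > 709

This performs the comparison only for non-NaN values; the NaN entries of result stay NaN, and the other entries become 1.0 (True) or 0.0 (False). Writing the booleans back into out_vec itself would overwrite the original values. If you need the NaN entries to be False as well, also do this:

result[np.isnan(result)] = False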


Source: https://stackoverflow.com/questions/37651803/runtimewarning-invalid-value-encountered-in-greater
