Avoid overflow with softplus function in Python

不思量自难忘° 2021-02-14 05:20

I am trying to implement the following softplus function:

log(1 + exp(x))

I've tried it with math/numpy and float64 as the data type, but whenever
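For reference, here is a minimal sketch of what a direct translation of the formula does for large x. The function names are hypothetical and this is an illustration, not the asker's actual code. With `math`, `math.exp` raises `OverflowError` once exp(x) exceeds the largest float64 (roughly x > 709.78); with NumPy, `np.exp` returns `inf` with a RuntimeWarning, so the result silently becomes `inf`.

```python
import math
import warnings

import numpy as np

# Hypothetical helper names; both are direct translations of log(1 + exp(x)).
def naive_softplus_math(x):
    # math.exp raises OverflowError once exp(x) exceeds the largest float64.
    return math.log(1 + math.exp(x))

def naive_softplus_numpy(x):
    # np.exp returns inf (with a RuntimeWarning) instead of raising,
    # so the whole expression silently evaluates to inf.
    return np.log1p(np.exp(x))

try:
    naive_softplus_math(1000.0)
except OverflowError as err:
    print("math:", err)  # math: math range error

with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)  # hide the overflow warning
    print("numpy:", naive_softplus_numpy(np.float64(1000.0)))  # numpy: inf
```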

4 answers
  •  被撕碎了的回忆
    2021-02-14 06:08

    Since log(1 + exp(x)) = x + log(1 + exp(-x)), for x > 30 we have log(1 + exp(x)) ≈ log(exp(x)) = x, so a simple, numerically stable implementation is:

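    import numpy as np

    def safe_softplus(x, limit=30):
        # Above the limit, log(1 + exp(x)) equals x to well within 1e-10,
        # so return x and never evaluate exp(x), which would overflow.
        if x > limit:
            return x
        # log1p(y) computes log(1 + y) accurately even for tiny y.
        return np.log1p(np.exp(x))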
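The `if x > limit` test in the function above works for scalars, but for a NumPy array with more than one element it raises `ValueError` (the truth value of the array is ambiguous). A minimal vectorized sketch, assuming float64 inputs, uses `np.where` and clips the argument of `exp` so the discarded branch cannot overflow. NumPy's built-in `np.logaddexp(0, x)`, which computes log(exp(0) + exp(x)) without overflow, is a swapped-in alternative that avoids choosing a limit by hand. Both function names here are hypothetical.

```python
import numpy as np

def safe_softplus_vec(x, limit=30.0):
    # Hypothetical vectorized variant of safe_softplus for NumPy arrays.
    x = np.asarray(x, dtype=np.float64)
    # np.where evaluates both branches, so clip x before exp() to keep the
    # discarded branch from overflowing for large x.
    small = np.log1p(np.exp(np.minimum(x, limit)))
    return np.where(x > limit, x, small)

def softplus_logaddexp(x):
    # Hypothetical wrapper: np.logaddexp(a, b) = log(exp(a) + exp(b)),
    # computed without overflow, so logaddexp(0, x) = log(1 + exp(x)).
    return np.logaddexp(0.0, x)

xs = np.array([-1000.0, -5.0, 0.0, 5.0, 30.0, 31.0, 1000.0])
print(safe_softplus_vec(xs))
print(softplus_logaddexp(xs))
```

On these inputs the two variants should agree to well within 1e-10; `np.logaddexp` is probably the simpler choice when NumPy is already a dependency.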
    

    In fact |log(1 + exp(30)) - 30| = log(1 + exp(-30)) ≈ 9.4e-14 < 1e-10, and this error only shrinks as x grows, so safe_softplus has an absolute error below 1e-10 and never overflows. For x = 1000, for example, the error of the approximation is about exp(-1000), far below float64 resolution (adjacent float64 values near 1000 are about 1.1e-13 apart), so it cannot even be measured on the computer.
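The error bound can be checked numerically. This sketch relies on the exact identity log(1 + exp(x)) = x + log(1 + exp(-x)), so the error of returning x is log1p(exp(-x)), which is below exp(-x); `np.spacing` gives the gap between a float64 value and the next one.

```python
import numpy as np

limit = 30.0
# Identity: log(1 + exp(x)) = x + log(1 + exp(-x)), so returning x for
# x > limit makes an absolute error of log1p(exp(-x)), which is below exp(-limit).
err_at_limit = np.log1p(np.exp(-limit))
print(err_at_limit)  # about 9.4e-14

# Gap between adjacent float64 values near 1000, versus the error there.
print(np.spacing(1000.0))  # about 1.1e-13
print(np.log1p(np.exp(-1000.0)))  # 0.0, because exp(-1000) underflows to zero
```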
