Differentiable round function in TensorFlow?

面向向阳花 2021-02-04 20:27

So the output of my network is a list of probabilities, which I then round using tf.round() to be either 0 or 1; this is crucial for this project. I then found out that tf.roun

6 Answers
  •  萌比男神i
    2021-02-04 21:01

    Building on a previous answer, a way to get an arbitrarily good approximation is to approximate round() with a truncated Fourier series, using as many terms as you need. Fundamentally, you can think of round(x) as x plus a reverse (i.e. descending) sawtooth wave. So, using the Fourier expansion of the sawtooth wave, we get

    round(x) ≈ x + (1/π) ∑_{n=1}^{N} (-1)^n sin(2π n x) / n

    With N = 5, we get a pretty nice approximation:
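
    The formula can be checked numerically with a minimal sketch like the one below. It uses only the Python standard library rather than TensorFlow, and the names fourier_round, fourier_round_grad and n_terms are made up for this illustration. The second function is the term-by-term derivative of the series, included to show that the gradient is not identically zero, whereas a hard round has zero derivative almost everywhere.

```python
import math


def fourier_round(x, n_terms=5):
    """Approximate round(x) as x plus a truncated sawtooth Fourier series.

    n_terms is the N in the formula above (hypothetical parameter name).
    """
    series = sum(
        (-1) ** n * math.sin(2 * math.pi * n * x) / n
        for n in range(1, n_terms + 1)
    )
    return x + series / math.pi


def fourier_round_grad(x, n_terms=5):
    """Analytic derivative of fourier_round with respect to x.

    Differentiating each sine term gives 1 + 2 * sum((-1)^n * cos(2*pi*n*x)).
    """
    series = sum(
        (-1) ** n * math.cos(2 * math.pi * n * x)
        for n in range(1, n_terms + 1)
    )
    return 1.0 + 2.0 * series


if __name__ == "__main__":
    # Compare the smooth approximation with Python's built-in round().
    for x in (0.1, 0.25, 0.4, 0.6, 0.75, 0.9):
        print(
            f"x={x:.2f}  approx={fourier_round(x):+.4f}  "
            f"exact={round(x)}  grad={fourier_round_grad(x):+.4f}"
        )
```

    In a TensorFlow model the same expression could presumably be written with tf.sin on tensors so that automatic differentiation computes the gradient; that variant is not shown or tested here. Note that the approximation passes through the midpoint at half-integers (it returns 0.5 at x = 0.5), and raising N tightens the fit at the cost of larger gradient swings near the jumps.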
