tensorflow: how come gather_nd is differentiable?

Submitted by 折月煮酒 on 2019-12-01 14:42:41

Question


I'm looking at a tensorflow network implementing reinforcement-learning for the CartPole open-ai env.

The network implements the likelihood ratio approach for a policy gradient agent.

The thing is that the policy loss is defined using the gather_nd op! Here, look:

    ....
    self.y = tf.nn.softmax(tf.matmul(self.W3, self.h2) + self.b3, dim=0)  # action probabilities
    self.curr_reward = tf.placeholder(shape=[None], dtype=tf.float32)
    self.actions_array = tf.placeholder(shape=[None, 2], dtype=tf.int32)
    self.pai_array = tf.gather_nd(self.y, self.actions_array)  # probability of each taken action
    self.L = -tf.reduce_mean(tf.log(self.pai_array) * self.curr_reward)

And then they take the derivative of this loss with respect to all the parameters of the network:

    self.gradients = tf.gradients(self.L,tf.trainable_variables())

How can this be? I thought the whole point of neural networks was to always work with differentiable ops, like cross-entropy, and never do something strange like selecting indices of self.y according to some randomly chosen self.actions_array, which is clearly not differentiable.
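Incidentally, the selection in the question is less strange than it looks: picking out the chosen action's probability with a gather and taking -log of it is exactly cross-entropy with a one-hot label on that action. A minimal NumPy sketch with made-up toy numbers (not from the post):

```python
import numpy as np

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])      # softmax output, one row per timestep
actions = np.array([0, 2])               # action chosen at each timestep

# gather-style loss: pick the chosen action's probability directly
gather_loss = -np.log(probs[np.arange(len(actions)), actions])

# cross-entropy with a one-hot label on the chosen action
one_hot = np.eye(probs.shape[1])[actions]
ce_loss = -np.sum(one_hot * np.log(probs), axis=1)

print(np.allclose(gather_loss, ce_loss))  # the two losses coincide
```

So the gather is just a sparse way of writing a loss the network community considers perfectly ordinary.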

What am I missing here? thanks!


Answer 1:


The gradient is one if the parameter is gathered and zero if it is not. One use-case for the gather operator is to act like a sparse one-hot matrix multiplication: the second argument is the dense representation of the sparse matrix, and you "multiply" it by the first argument simply by selecting the right rows.
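That "sparse one-hot matmul" view can be sketched in NumPy (toy values, not from the answer):

```python
import numpy as np

params = np.arange(12.0).reshape(4, 3)   # dense parameter table
idx = np.array([2, 0, 2])                # rows to gather

gathered = params[idx]                   # analogue of tf.gather(params, idx)

# equivalent dense formulation: a one-hot selection matrix times params
sel = np.eye(params.shape[0])[idx]       # shape (3, 4), one-hot rows
via_matmul = sel @ params

assert np.allclose(gathered, via_matmul)
```

Since the selection matrix contains only zeros and ones, the gradient flowing back to each row of params is either the upstream gradient (row was gathered) or zero (row was not).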




Answer 2:


There is no official documentation on this, but according to this issue: https://github.com/tensorflow/models/issues/295 the gradient of tf.gather in the TensorFlow implementation is 1 w.r.t. self.y and 0 w.r.t. the index. Therefore, it will not propagate the gradient through the index.




Answer 3:


It's only differentiable w.r.t. self.y but not the integer/discrete elements of self.actions_array.
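Concretely, the backward pass of a gather is a scatter-add: the upstream gradient is routed into the gathered rows of the parameter gradient, and every other row gets zero. A minimal NumPy sketch of what the registered gradient effectively does (the function name and shapes here are illustrative, not TensorFlow's actual internals):

```python
import numpy as np

def gather_backward(upstream, idx, params_shape):
    """Manual backward pass of a row gather: scatter-add the upstream
    gradient into the gathered rows; all other rows stay zero."""
    grad = np.zeros(params_shape)
    np.add.at(grad, idx, upstream)       # accumulates over repeated indices
    return grad

idx = np.array([2, 0, 2])                # row 2 gathered twice
upstream = np.ones((3, 3))               # pretend dL/d(gathered) is all ones
grad = gather_backward(upstream, idx, (4, 3))
# row 0 receives 1s, row 2 receives 2s (gathered twice), rows 1 and 3 stay 0
```

No gradient is computed for idx at all, which is why the integer actions_array never needs to be differentiable.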



Source: https://stackoverflow.com/questions/45701722/tensorflow-how-come-gather-nd-is-differentiable
