Tensorflow: _variable_with_weight_decay(…) explanation


Question


At the moment I'm looking at the cifar10 example, and I noticed the function _variable_with_weight_decay(...) in the file cifar10.py. The code is as follows:

def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.
  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.
  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.
  Returns:
    Variable Tensor
  """
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  var = _variable_on_cpu(
      name,
      shape,
      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
  if wd is not None:
    # tf.mul was renamed tf.multiply in TensorFlow 1.0
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var
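
For context, the helper it calls, _variable_on_cpu, is defined in the same cifar10.py; roughly (paraphrased, so details may differ between TensorFlow versions):

def _variable_on_cpu(name, shape, initializer):
  """Helper to create a Variable stored on CPU memory."""
  with tf.device('/cpu:0'):
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
  return var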

I'm wondering if this function does what it says. It is clear that when a weight decay factor is given (wd is not None) the decay value (weight_decay) is computed. But is it ever applied? At the end the unmodified variable (var) is returned, or am I missing something?

My second question would be how to fix this. As I understand it, the value of the scalar weight_decay must be subtracted from each element in the weight matrix, but I'm unable to find a TensorFlow op that can do that (adding/subtracting a single value from every element of a tensor). Is there any op like this? As a workaround I thought it might be possible to create a new tensor initialized with the value of weight_decay and use tf.subtract(...) to achieve the same result. Or is this the right way to go anyway?

Thanks in advance.


Answer 1:


The code does what it says. You are supposed to sum everything in the 'losses' collection (to which the weight decay term is added in the second-to-last line) to get the loss that you pass to the optimizer. In the loss() function in that example:

tf.add_to_collection('losses', cross_entropy_mean)
[...]
return tf.add_n(tf.get_collection('losses'), name='total_loss')

So what the loss() function returns is the classification loss plus everything that was already in the 'losses' collection.
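As a minimal sketch of the pattern (the layer shape, wd value and learning rate here are illustrative, not taken from cifar10.py):

# Every layer that wants weight decay adds its penalty to the 'losses' collection.
weights = _variable_with_weight_decay('weights', shape=[256, 10],
                                      stddev=0.04, wd=0.004)

# The data term goes into the same collection ...
tf.add_to_collection('losses', cross_entropy_mean)
# ... and the total loss is the sum of everything collected so far.
total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')

# Minimizing total_loss differentiates through each wd * l2_loss(var) term,
# which is how the decay actually reaches the weights.
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(total_loss)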

As a side note, weight decay does not mean you subtract the value of wd from every value in the tensor as part of the update step; rather, it multiplies each value by (1 - learning_rate * wd) (in plain SGD). To see why this is so, recall that l2_loss computes

output = sum(t_i ** 2) / 2

with t_i being the elements of the tensor. This means that the derivative of l2_loss with respect to each tensor element is the value of that tensor element itself, and since you scaled l2_loss with wd the derivative is scaled as well.

Since the update step (again, in plain SGD) is (forgive me for omitting the time step indexes)

w := w - learning_rate * dL/dw

you get, if you only had the weight decay term

w := w - learning_rate * wd * w

or

w := w * (1 - learning_rate * wd)
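
You can check this numerically. The snippet below is my own sketch (TF 1.x-style API, illustrative values for learning_rate, wd and the initial weights); it takes one SGD step with only the weight-decay term as the loss and compares the result to w * (1 - learning_rate * wd):

import numpy as np
import tensorflow as tf

learning_rate, wd = 0.1, 0.004
w0 = np.array([1.0, -2.0, 3.0], dtype=np.float32)

w = tf.Variable(w0)
weight_decay = tf.multiply(tf.nn.l2_loss(w), wd)  # wd * sum(w_i ** 2) / 2
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(weight_decay)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)                     # one SGD step on the decay term alone
  print(sess.run(w))                     # [ 0.9996 -1.9992  2.9988]
  print(w0 * (1 - learning_rate * wd))   # matches the line above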


Source: https://stackoverflow.com/questions/41714801/tensorflow-variable-with-weight-decay-explanation
