Question
At the moment I'm looking at the cifar10 example, and I noticed the function _variable_with_weight_decay(...) in the file cifar10.py. The code is as follows:
def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  var = _variable_on_cpu(
      name,
      shape,
      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
  if wd is not None:
    weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var
I'm wondering if this function does what it says. It is clear that when a weight decay factor is given (wd is not None), the decay value (weight_decay) is computed. But is it ever applied? In the end, the unmodified variable (var) is returned, or am I missing something?
My second question would be how to fix this. As I understand it, the value of the scalar weight_decay must be subtracted from each element of the weight matrix, but I'm unable to find a TensorFlow op that can do that (adding/subtracting a single value to/from every element of a tensor). Is there such an op? As a workaround, I thought it might be possible to create a new tensor initialized with the value of weight_decay and use tf.subtract(...) to achieve the same result. Or is that the right way to go anyway?
Thanks in advance.
Answer 1:
The code does what it says. You are supposed to sum everything in the 'losses' collection (which the weight decay term is added to in the second to last line) for the loss that you pass to the optimizer. In the loss() function in that example:
tf.add_to_collection('losses', cross_entropy_mean)
[...]
return tf.add_n(tf.get_collection('losses'), name='total_loss')
so what the loss() function returns is the classification loss plus everything that was in the 'losses' collection before.
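To make that mechanism concrete, here is a minimal, self-contained sketch (assuming a TF 1.x-style graph API; the shapes, the wd value, the stddev, and the constant standing in for cross_entropy_mean are made up, and tf.multiply is the later name of the tf.mul used in the quoted code). It shows how a weight decay term registered in the 'losses' collection ends up inside the loss the optimizer actually minimizes:

import tensorflow as tf  # assumes a TF 1.x-era graph API, as in the question

# Create a weight matrix and register its scaled L2 norm in the 'losses'
# collection, mirroring what _variable_with_weight_decay does.
wd = 0.004
weights = tf.get_variable(
    'weights', shape=[10, 5],
    initializer=tf.truncated_normal_initializer(stddev=0.04))
weight_decay = tf.multiply(tf.nn.l2_loss(weights), wd, name='weight_loss')
tf.add_to_collection('losses', weight_decay)

# Elsewhere, the data-fit term is added to the same collection ...
cross_entropy_mean = tf.constant(1.23, name='cross_entropy')  # stand-in for the real loss
tf.add_to_collection('losses', cross_entropy_mean)

# ... and the total loss handed to the optimizer is the sum of the collection,
# so minimizing it also pushes the weights toward zero.
total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(total_loss)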
As a side note, weight decay does not mean you subtract the value of wd from every value in the tensor as part of the update step; it multiplies each value by (1 - learning_rate * wd) (in plain SGD). To see why this is so, recall that l2_loss computes
output = sum(t_i ** 2) / 2
with t_i being the elements of the tensor. This means that the derivative of l2_loss with respect to each tensor element is the value of that tensor element itself, and since you scaled l2_loss with wd, the derivative is scaled as well.
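If you want to convince yourself of that, here is a small check (plain NumPy, all names and numbers made up) comparing the claimed gradient wd * w against a finite-difference estimate of wd * l2_loss(w):

import numpy as np

def l2_loss(w):
    # Same definition as tf.nn.l2_loss: sum(w_i ** 2) / 2
    return np.sum(w ** 2) / 2.0

wd = 0.004
w = np.random.randn(5)

analytic = wd * w  # claimed gradient of wd * l2_loss(w): wd * w_i per element

# Central finite-difference estimate of the same gradient
eps = 1e-6
numeric = np.zeros_like(w)
for i in range(len(w)):
    w_plus, w_minus = w.copy(), w.copy()
    w_plus[i] += eps
    w_minus[i] -= eps
    numeric[i] = wd * (l2_loss(w_plus) - l2_loss(w_minus)) / (2 * eps)

print(np.allclose(analytic, numeric))  # True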
Since the update step (again, in plain SGD) is (forgive me for omitting the time step indexes)
w := w - learning_rate * dL/dw
you get, if you only had the weight decay term
w := w - learning_rate * wd * w
or
w := w * (1 - learning_rate * wd)
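A quick numeric check of that last line (made-up numbers, plain NumPy, with only the weight decay term in the loss):

import numpy as np

learning_rate = 0.1
wd = 0.004
w = np.array([1.0, -2.0, 0.5])

# One plain-SGD step using only the weight decay gradient: dL/dw = wd * w
w_step = w - learning_rate * (wd * w)

# The same step written as a multiplicative shrinkage of the weights
w_shrunk = w * (1 - learning_rate * wd)

print(np.allclose(w_step, w_shrunk))  # True: the two forms are identical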
Source: https://stackoverflow.com/questions/41714801/tensorflow-variable-with-weight-decay-explanation