Selectively zero weights in TensorFlow?

痞子三分冷 posted on 2019-12-21 16:03:01

Question


Let's say I have an NxM weight variable, weights, and a constant NxM matrix of 1s and 0s, mask.

If a layer of my network is defined like this (with other layers similarly defined):

masked_weights = mask*weights
layer1 = tf.nn.relu(tf.matmul(layer0, masked_weights) + biases1)

Will this network behave during training as if the weights at the positions where mask is 0 were themselves zero (i.e. as if the connections represented by those weights had been removed from the network entirely)?

If not, how can I achieve this goal in TensorFlow?


Answer 1:


The answer is yes. The following experiment demonstrates it.

The implementation is:

# TensorFlow 1.x API (tf.placeholder, tf.Session)
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 3))
weights = tf.get_variable("weights", [3, 2])
bias = tf.get_variable("bias", [2])
mask = tf.constant(np.asarray([[0, 1], [1, 0], [0, 1]], dtype=np.float32)) # constant mask

masked_weights = tf.multiply(weights, mask)
y = tf.nn.relu(tf.nn.bias_add(tf.matmul(x, masked_weights), bias))
loss = tf.losses.mean_squared_error(tf.constant(np.asarray([[1, 1]], dtype=np.float32)),y)

weights_grad = tf.gradients(loss, weights)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print("Masked weights=\n", sess.run(masked_weights))
data = np.random.rand(1, 3)

print("Gradient of weights\n=", sess.run(weights_grad, feed_dict={x: data}))
sess.close()

After running the code above, you will see the gradients are masked as well. In my example, they are:

Gradient of weights
= [array([[ 0.        , -0.40866762],
       [ 0.34265977, -0.        ],
       [ 0.        , -0.35294518]], dtype=float32)]
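The practical consequence of the masked gradient can be sketched without TensorFlow. The following is a minimal pure-Python illustration, not the answer's code: the initial weights, bias, input row, learning rate and step count are hypothetical values chosen for the example, and the update rule is plain SGD with the bias held fixed. It suggests that entries of weights under a 0 in mask keep their initial values, while the effective (masked) weights at those positions stay zero.

```python
# Pure-Python sketch: one masked linear layer + ReLU trained with plain SGD.
# All numeric values below are hypothetical and only serve the illustration.
mask = [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]       # same 3x2 mask as in the answer
weights = [[0.5, -0.3], [0.8, 0.2], [-0.6, 0.9]]  # assumed initial weights
bias = [0.1, 0.1]                                 # assumed bias, not trained here
x = [0.2, 0.7, 0.4]                               # one assumed input row
target = [1.0, 1.0]
learning_rate = 0.1


def forward(w):
    """Return (masked weights, pre-activations, ReLU outputs)."""
    masked = [[w[i][j] * mask[i][j] for j in range(2)] for i in range(3)]
    pre = [sum(x[i] * masked[i][j] for i in range(3)) + bias[j] for j in range(2)]
    y = [max(0.0, p) for p in pre]
    return masked, pre, y


def weight_grads(w):
    """Gradient of mean((y - target)^2) with respect to the unmasked weights."""
    _, pre, y = forward(w)
    # d(loss)/d(pre_j): the 2 from the square cancels the 1/2 from the mean,
    # and the ReLU passes the gradient only where pre_j > 0.
    d_pre = [(y[j] - target[j]) * (1.0 if pre[j] > 0 else 0.0) for j in range(2)]
    # Chain rule through masked = w * mask multiplies by mask[i][j].
    return [[x[i] * d_pre[j] * mask[i][j] for j in range(2)] for i in range(3)]


initial = [row[:] for row in weights]
for _ in range(50):
    g = weight_grads(weights)
    weights = [[weights[i][j] - learning_rate * g[i][j] for j in range(2)]
               for i in range(3)]

masked_final, _, _ = forward(weights)
print("initial weights:", initial)
print("trained weights:", weights)
print("effective (masked) weights:", masked_final)
```

Note that the underlying variable is not zero at the masked positions; only the product mask * weights is. Anything that reads weights directly (for example a regularizer, or exporting the variable) would still see the initial values there.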



Answer 2:


The answer is yes, and the reason lies in backpropagation, as explained below.

mask_w = mask * w

d(mask_w) = mask * d(w), so by the chain rule dL/dw = mask * dL/d(mask_w), where L is the loss and * is the element-wise product.

The mask makes the gradient 0 wherever its value is 0. Wherever its value is 1, the gradient flows as before. This is a common trick in seq2seq models for masking variable-length outputs in the decoding layer. You can read more about this here.
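The relation between the gradient with respect to the raw weights and the gradient with respect to the masked weights can be checked numerically. The sketch below uses central finite differences in plain Python. The loss function, the weight values and the input row are assumptions made up for this check, and the helper names (loss_from_effective, apply_mask, numeric_grad) are hypothetical.

```python
# Finite-difference check that dL/dw equals mask * dL/d(mask_w), element-wise.
# The loss, weights and input below are hypothetical illustration values.
mask = [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
w = [[0.5, -0.3], [0.8, 0.2], [-0.6, 0.9]]
x = [0.2, 0.7, 0.4]


def loss_from_effective(mw):
    """A smooth loss of the effective weights: sum_j (sum_i x_i * mw_ij - 1)^2."""
    return sum((sum(x[i] * mw[i][j] for i in range(3)) - 1.0) ** 2
               for j in range(2))


def apply_mask(m):
    """Element-wise product of a 3x2 matrix with the constant mask."""
    return [[m[i][j] * mask[i][j] for j in range(2)] for i in range(3)]


def numeric_grad(f, m, eps=1e-6):
    """Central-difference gradient of f with respect to each entry of m."""
    grad = [[0.0, 0.0] for _ in range(3)]
    for i in range(3):
        for j in range(2):
            plus = [row[:] for row in m]
            minus = [row[:] for row in m]
            plus[i][j] += eps
            minus[i][j] -= eps
            grad[i][j] = (f(plus) - f(minus)) / (2 * eps)
    return grad


# Gradient with respect to the raw weights (the mask is applied inside the loss).
grad_w = numeric_grad(lambda m: loss_from_effective(apply_mask(m)), w)
# Gradient with respect to the masked weights, treated as free variables.
grad_mw = numeric_grad(loss_from_effective, apply_mask(w))

expected = [[mask[i][j] * grad_mw[i][j] for j in range(2)] for i in range(3)]
max_diff = max(abs(grad_w[i][j] - expected[i][j])
               for i in range(3) for j in range(2))
print("dL/dw         =", grad_w)
print("dL/d(mask_w)  =", grad_mw)
print("max |dL/dw - mask * dL/d(mask_w)| =", max_diff)
```

In this setup dL/d(mask_w) is generally nonzero at the masked positions; it is the multiplication by mask in the chain rule that removes those entries from dL/dw.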



Source: https://stackoverflow.com/questions/38278965/selectively-zero-weights-in-tensorflow
