How to freeze/lock weights of one TensorFlow variable (e.g., one CNN kernel of one layer)

Submitted by 有些话、适合烂在心里 on 2019-12-03 04:36:22

Question


I have a TensorFlow CNN model that performs well, and we would like to implement it in hardware, i.e., on an FPGA. It's a relatively small network, but it would be ideal if it were smaller. With that goal, I've examined the kernels and found that some have quite strong weights while others aren't doing much at all (their values are all close to zero). This occurs specifically in layer 2, corresponding to the tf.Variable() named "W_conv2", which has shape [3, 3, 32, 32]. I would like to freeze/lock the values of W_conv2[:, :, 29, 13] at zero so that the rest of the network can be trained to compensate. Setting this kernel's values to zero effectively removes/prunes it from the hardware implementation, achieving the goal stated above.

I have found similar questions with suggestions that generally revolve around one of two approaches:

Suggestion #1:

    tf.Variable(some_initial_value, trainable = False)

Implementing this suggestion freezes the entire variable. I want to freeze just a slice, specifically W_conv2[:, :, 29, 13].

Suggestion #2:

    Optimizer = tf.train.RMSPropOptimizer(0.001).minimize(loss, var_list)

Again, this suggestion does not allow the use of slices. For instance, if I try the inverse of my stated goal (optimizing only a single kernel of a single variable) as follows:

    Optimizer = tf.train.RMSPropOptimizer(0.001).minimize(loss, var_list = W_conv2[:, :, 0, 0])

I get the following error:

    NotImplementedError: ('Trying to optimize unsupported type ', <tf.Tensor 'strided_slice_2228:0' shape=(3, 3) dtype=float32>)

Slicing a tf.Variable() isn't possible in the way I've tried here: indexing a variable produces a tf.Tensor (a strided_slice op), not a variable, which is why the optimizer rejects it. The only thing I've tried that comes close to doing what I want is .assign(), but it's extremely inefficient, cumbersome, and caveman-like as I've implemented it below (after the model is trained):

    for _ in range(10000):
        # get a new batch of data
        # reset W_conv2[:, :, 29, 13] to zero each time through
        # (building fresh assign ops every iteration also bloats the graph)
        for m in range(3):
            for n in range(3):
                assign_op = W_conv2[m, n, 29, 13].assign(0)
                sess.run(assign_op)
        # re-train the rest of the network
        _, loss_val = sess.run([optimizer, loss], feed_dict = {
                                   dict_stuff_here
                               })
        print(loss_val)
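
Framework aside, the "re-zero after every update" loop above amounts to a projection step. Here is a minimal NumPy sketch of that idea, using a hypothetical toy kernel (shape [3, 3, 4, 4] rather than the actual [3, 3, 32, 32], and random numbers in place of real gradients): after each update, the pruned slice is set back to zero.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for W_conv2, shape [3, 3, 4, 4] instead of [3, 3, 32, 32].
W = rng.normal(size=(3, 3, 4, 4))
pruned = (slice(None), slice(None), 2, 1)  # stands in for [:, :, 29, 13]

lr = 0.1
for _ in range(100):
    grad = rng.normal(size=W.shape)  # placeholder for a real backprop gradient
    W -= lr * grad                   # ordinary SGD update touches every weight
    W[pruned] = 0.0                  # projection: force the pruned slice back to zero
```

The rest of the weights train freely; the pruned slice ends every step at exactly zero.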

The model was started in Keras then moved to TensorFlow since Keras didn't seem to have a mechanism to achieve the desired results. I'm starting to think that TensorFlow doesn't allow for pruning but find this hard to believe; it just needs the correct implementation.


Answer 1:


A possible approach is to initialize these specific weights to zero and modify the minimization process so that gradients are never applied to them. This can be done by replacing the call to minimize() with something like:

    # The mask must match W_conv2's dtype (float32); np.ones defaults to float64.
    W_conv2_weights = np.ones((3, 3, 32, 32), dtype=np.float32)
    W_conv2_weights[:, :, 29, 13] = 0
    W_conv2_weights_const = tf.constant(W_conv2_weights)

    optimizer = tf.train.RMSPropOptimizer(0.001)

    # tf.gradients returns a list of tensors; take its single element before masking.
    W_conv2_orig_grad = tf.gradients(loss, [W_conv2])[0]
    W_conv2_grad = tf.multiply(W_conv2_weights_const, W_conv2_orig_grad)
    # apply_gradients expects a list of (gradient, variable) pairs.
    W_conv2_train_op = optimizer.apply_gradients([(W_conv2_grad, W_conv2)])

    rest_grads = tf.gradients(loss, rest_of_vars)
    rest_train_op = optimizer.apply_gradients(zip(rest_grads, rest_of_vars))

    train_op = tf.group(rest_train_op, W_conv2_train_op)

That is:

  1. Prepare a constant tensor for canceling the appropriate gradients.
  2. Compute the gradient only for W_conv2, multiply it element-wise with the constant W_conv2_weights_const to zero out the pruned kernel's gradient, and only then apply it.
  3. Compute and apply gradients "normally" to the rest of the variables.
  4. Group the two train ops into a single training op.

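The masked-gradient recipe can also be sketched framework-free in NumPy (hypothetical toy shapes, plain SGD in place of RMSProp): multiply the gradient by a 0/1 mask before applying it, so the pruned slice, once initialized to zero, never moves.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for W_conv2, shape [3, 3, 4, 4] instead of [3, 3, 32, 32].
W = rng.normal(size=(3, 3, 4, 4))
mask = np.ones_like(W)
mask[:, :, 2, 1] = 0.0  # stands in for W_conv2_weights[:, :, 29, 13] = 0
W *= mask               # initialize the pruned kernel to zero

lr = 0.1
for _ in range(100):
    grad = rng.normal(size=W.shape)  # placeholder for a real backprop gradient
    W -= lr * (mask * grad)          # masked update: the pruned slice gets zero gradient
```

Because the masked slice starts at zero and its gradient is always zeroed, it stays at exactly zero for the whole run, with no per-step assign needed.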

Source: https://stackoverflow.com/questions/42517926/how-to-freeze-lock-weights-of-one-tensorflow-variable-e-g-one-cnn-kernel-of-o
