Update a subset of weights in TensorFlow

予麋鹿 2020-12-16 06:21

Does anyone know how to update a subset (i.e. only some indices) of the weights that are used in the forward propagation?

My guess is that I might be able to do that

2 Answers
  • 2020-12-16 07:16

    You could use a combination of tf.gather and tf.scatter_update (TensorFlow 1.x graph mode). Here's an example that doubles the values at positions 0 and 2:

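    import tensorflow.compat.v1 as tf  # TF1-style graph API; also runs under TF 2.x
    tf.disable_v2_behavior()
    
    indices = tf.constant([0, 2])
    data = tf.Variable([1, 2, 3])
    # Read only the entries at the given indices
    data_subset = tf.gather(data, indices)
    # Compute new values for that subset (here: double them)
    updated_data_subset = 2 * data_subset
    # Write the new values back into data at the same indices
    sparse_update = tf.scatter_update(data, indices, updated_data_subset)
    init_op = tf.global_variables_initializer()  # replaces deprecated initialize_all_variables
    
    with tf.Session() as sess:
        sess.run([init_op])
        print("Values before:", sess.run([data]))
        sess.run([sparse_update])
        print("Values after:", sess.run([data]))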
    

    You should see

    Values before: [array([1, 2, 3], dtype=int32)]
    Values after: [array([2, 2, 6], dtype=int32)]
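    On TensorFlow 2.x, where eager execution is the default and tf.Session no longer exists, the same gather-then-scatter pattern can be written with the variable's own scatter_nd_update method. This is a minimal sketch assuming TF 2.x: it first reproduces the doubling example, then applies the same pattern to a single plain gradient-descent step so that only the weights at the chosen indices change. The toy loss, the weight values, and the learning rate are illustrative assumptions, not part of the answer.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution (the default)

# Doubling example: same values and indices as the session-based code above
data = tf.Variable([1, 2, 3])
indices = tf.constant([0, 2])

# Read the selected entries, compute new values, write them back in place.
# scatter_nd_update expects indices of shape [N, 1] for a 1-D variable.
data_subset = tf.gather(data, indices)
data.scatter_nd_update(tf.expand_dims(indices, 1), 2 * data_subset)
print("Values after:", data.numpy())  # prints: Values after: [2 2 6]

# Same pattern applied to one gradient-descent step, so only the selected
# weights change. The loss and learning rate below are hypothetical.
w = tf.Variable([1.0, 2.0, 3.0])
learning_rate = 0.1
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(w ** 2)  # toy loss for illustration only
grad = tape.gradient(loss, w)
# Subtract learning_rate * gradient only at the chosen indices;
# the entry at index 1 is left untouched.
w.scatter_nd_sub(tf.expand_dims(indices, 1),
                 learning_rate * tf.gather(grad, indices))
print("Weights after one step:", w.numpy())
```

    The gradient is still computed for every entry of w; the scatter simply discards the parts you do not want to apply, which is why index 1 keeps its original value.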
    
  • 2020-12-16 07:18

    The easiest way is to pull the tf.Variable into Python as a NumPy array using npvar = sess.run(tfvar), then modify it, for example npvar[1, 2] = -10. You can then upload the modified data back into TensorFlow using sess.run(tfvar.assign(npvar)).

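    This is slow, because the entire variable is copied into Python and back on every update, and each call to tfvar.assign(npvar) adds a new op to the graph, so it is not practical inside a training loop. It does work, though.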
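    Under TensorFlow 2.x with eager execution, the same round trip needs no session: Variable.numpy() returns a copy of the current values as a NumPy array, and Variable.assign writes a full array back. A minimal sketch, assuming TF 2.x; the 3x4 shape and zero initial values of tfvar are made up for illustration.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# Hypothetical 2-D variable standing in for "tfvar" from the answer
tfvar = tf.Variable(tf.zeros([3, 4]))

# 1. Copy the current values into a NumPy array (replaces sess.run(tfvar))
npvar = tfvar.numpy()
# 2. Edit the array in plain Python/NumPy
npvar[1, 2] = -10
# 3. Write the whole modified array back into the variable
tfvar.assign(npvar)
print(tfvar.numpy())
```

    The cost noted above still applies: the full variable crosses between TensorFlow and Python on every edit, even if only one entry changes.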
