Does anyone know how to update a subset (i.e. only some indices) of the weights that are used in the forward propagation?
My guess is that I might be able to do that
You could use a combination of tf.gather and tf.scatter_update. Here's an example that doubles the values at positions 0 and 2:
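```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API

indices = tf.constant([0, 2])
data = tf.Variable([1, 2, 3])

# Read the selected entries and compute their new values...
data_subset = tf.gather(data, indices)
updated_data_subset = 2 * data_subset
# ...then write them back into the same positions of the variable.
sparse_update = tf.scatter_update(data, indices, updated_data_subset)

init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run([init_op])
print("Values before:", sess.run([data]))
sess.run([sparse_update])
print("Values after:", sess.run([data]))
```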
You should see:

```
Values before: [array([1, 2, 3], dtype=int32)]
Values after: [array([2, 2, 6], dtype=int32)]
```
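If the gather/scatter pattern is unfamiliar, it may help to see the same two steps in NumPy terms: tf.gather behaves roughly like fancy indexing (data[indices]) and tf.scatter_update like indexed assignment (data[indices] = updates). This is only an analogy to illustrate the semantics, not TensorFlow code; it ignores things like duplicate indices.

```python
import numpy as np

data = np.array([1, 2, 3], dtype=np.int32)
indices = np.array([0, 2])

# Analogous to tf.gather(data, indices): pick out the selected entries.
subset = data[indices]

# Analogous to tf.scatter_update(data, indices, 2 * subset):
# write the new values back into the same positions, leaving the rest alone.
data[indices] = 2 * subset

print(data)
```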
The easiest way is to pull the tf.Variable into Python as a NumPy array with npvar = sess.run(tfvar), modify the entries you want (for example npvar[1, 2] = -10), and then upload the modified data back into TensorFlow with sess.run(tfvar.assign(npvar)).
This is slow, since the entire variable is copied between TensorFlow and Python on every update, so it's not really useful inside a training loop, but it does work.
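Below is a runnable sketch of that round trip. It assumes graph mode through tf.compat.v1 (available in recent 1.x releases and in 2.x), and uses a hypothetical 3x3 float32 variable in place of tfvar. One variation from the one-liner above: the assign op is built once and fed through a placeholder, because calling tfvar.assign(npvar) inside sess.run creates a new constant and assign op in the graph every time it is called.

```python
import numpy as np
import tensorflow as tf

# Use TF1-style graph mode so Session/placeholder behave as in the answer.
tf.compat.v1.disable_eager_execution()

# Hypothetical 3x3 weight matrix standing in for tfvar.
tfvar = tf.Variable(np.arange(9, dtype=np.float32).reshape(3, 3))

# Build the upload op once; new values are fed in through the placeholder.
new_value = tf.compat.v1.placeholder(tf.float32, shape=(3, 3))
upload = tfvar.assign(new_value)

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    npvar = np.array(sess.run(tfvar))  # download a writable NumPy copy
    npvar[1, 2] = -10                  # change only the entries you care about
    sess.run(upload, feed_dict={new_value: npvar})
    result = sess.run(tfvar)           # read back the updated variable

print(result)
```

Reusing one placeholder-fed op keeps the graph from growing, but each update still copies the whole variable, so the cost per update is unchanged.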