Tensorflow: How to get all variables from rnn_cell.BasicLSTMCell & rnn_cell.MultiRNNCell

Anonymous (unverified), submitted 2019-12-03 02:50:02

Question:

I have a setup where I need to initialize an LSTM after the main initialization, which uses tf.initialize_all_variables(). That is, I want to call tf.initialize_variables(var_list).

Is there a way to collect all the internal trainable variables of both:

  • rnn_cell.BasicLSTMCell
  • rnn_cell.MultiRNNCell

so that I can initialize JUST these parameters?

The main reason is that I do not want to re-initialize values that were already trained earlier.
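The distinction the question relies on can be sketched with a toy model in plain Python. This is not TensorFlow code: `ToySession`, `create`, `run_init`, and the variable names are hypothetical stand-ins. A global initializer resets every variable, whereas initializing an explicit list touches only those variables and leaves previously trained values alone.

```python
class ToySession:
    """Toy stand-in for a session: tracks each variable's initial and current value."""

    def __init__(self):
        self.initial = {}  # variable name -> initial value
        self.values = {}   # variable name -> current value

    def create(self, name, init_value):
        # Declaring a variable records its initializer but does not set a value.
        self.initial[name] = init_value

    def run_init(self, names=None):
        """Reset the named variables, or every variable if names is None."""
        for name in (self.initial if names is None else names):
            self.values[name] = self.initial[name]


sess = ToySession()
sess.create("embedding/W", 0.0)
sess.run_init()                    # analogous to initialize_all_variables()
sess.values["embedding/W"] = 3.5   # pretend this value was trained

sess.create("LSTM/BasicLSTMCell/Linear/Matrix", 0.0)  # LSTM built afterwards
sess.run_init(["LSTM/BasicLSTMCell/Linear/Matrix"])   # initialize just the LSTM
print(sess.values["embedding/W"])  # 3.5: the trained value is untouched
```

Calling `sess.run_init()` with no arguments at the last step would instead reset `embedding/W` to 0.0, which is exactly what the question wants to avoid.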

Answer 1:

The easiest way to solve your problem is to use a variable scope: the names of all variables created within a scope are prefixed with the scope's name. Here is a short snippet:

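```python
import tensorflow as tf
from tensorflow.python.ops import rnn_cell  # TensorFlow 0.x; older releases used tensorflow.models.rnn

# Assumed to exist already: num_nodes, num_steps, input_data, an initial
# state (e.g. from cell.zero_state(...)) and a tf.Session named sess.
cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
    # Execute the LSTM cell here in any way, for example:
    output = [None] * num_steps
    for i in range(num_steps):
        if i > 0:
            vs.reuse_variables()  # share the cell's weights across time steps
        output[i], state = cell(input_data[i], state)

    # Retrieve just the LSTM variables.
    lstm_variables = [v for v in tf.all_variables()
                      if v.name.startswith(vs.name)]

# [..]

# Initialize the LSTM variables. initialize_variables() only builds an op;
# it has to be run in a session to take effect.
sess.run(tf.initialize_variables(lstm_variables))
```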

It would work the same way with MultiRNNCell.

EDIT: changed tf.trainable_variables() to tf.all_variables().
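Why the prefix test works can be illustrated with a plain-Python toy of nested scopes. The helpers `variable_scope` and `create_variable` here are hypothetical and only mimic the naming rule, not TensorFlow's implementation: every name created inside a scope gets the scope name plus "/" prepended, so a `startswith` check selects that scope's variables. The toy also shows one caveat of a bare prefix: `startswith("LSTM")` also matches a sibling scope such as `LSTM2`, so appending "/" to the prefix is a stricter filter.

```python
from contextlib import contextmanager

_scope_stack = []   # current nesting of scope names
all_variables = []  # toy global collection of variable names


@contextmanager
def variable_scope(name):
    """Toy scope: names created inside are prefixed with 'outer/inner/'."""
    _scope_stack.append(name)
    try:
        yield "/".join(_scope_stack)  # the full scope name, playing the role of vs.name
    finally:
        _scope_stack.pop()


def create_variable(name):
    full_name = "/".join(_scope_stack + [name])
    all_variables.append(full_name)
    return full_name


create_variable("embedding")
with variable_scope("LSTM") as lstm_scope:
    with variable_scope("BasicLSTMCell"):
        create_variable("Matrix")
        create_variable("Bias")
with variable_scope("LSTM2"):
    create_variable("Matrix")

naive = [v for v in all_variables if v.startswith(lstm_scope)]
strict = [v for v in all_variables if v.startswith(lstm_scope + "/")]
print(naive)   # also picks up 'LSTM2/Matrix'
print(strict)  # ['LSTM/BasicLSTMCell/Matrix', 'LSTM/BasicLSTMCell/Bias']
```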



Answer 2:

You can also use tf.get_collection():

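```python
import tensorflow as tf
from tensorflow.python.ops import rnn_cell  # TensorFlow 0.x

cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
    # Execute the LSTM cell here in any way, for example:
    output = [None] * num_steps
    for i in range(num_steps):
        if i > 0:
            vs.reuse_variables()  # share the cell's weights across time steps
        output[i], state = cell(input_data[i], state)

    lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)
```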

(partly copied from Rafal's answer)

Note that the last line is equivalent to the list comprehension in Rafal's code.

Basically, TensorFlow stores a global collection of variables in the graph, which can be fetched with either tf.all_variables() or tf.get_collection(tf.GraphKeys.VARIABLES). If you pass a scope name as the scope argument of tf.get_collection(), only the items in the collection (variables, in this case) whose names fall under that scope are returned.
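The collection mechanism described above can be modelled with a dictionary that maps a collection key to a list of names. This is a toy in plain Python, not TensorFlow's code: `ToyGraph` and its methods are hypothetical, and the `scope` filter is assumed to be a plain name-prefix test. Under that assumption, filtering by scope gives the same result as the list comprehension over all variables from the first answer.

```python
class ToyGraph:
    """Toy stand-in for a graph's named collections (not TensorFlow's API)."""

    def __init__(self):
        self._collections = {}  # collection key -> list of variable names

    def add_to_collection(self, key, name):
        self._collections.setdefault(key, []).append(name)

    def get_collection(self, key, scope=None):
        names = self._collections.get(key, [])
        if scope is None:
            return list(names)  # the whole collection, like all_variables()
        # Assumed behaviour: keep only names that start with the scope name.
        return [n for n in names if n.startswith(scope)]


graph = ToyGraph()
for name in ["embedding/W", "LSTM/BasicLSTMCell/Linear/Matrix",
             "LSTM/BasicLSTMCell/Linear/Bias", "softmax/W"]:
    graph.add_to_collection("variables", name)

all_vars = graph.get_collection("variables")
by_scope = graph.get_collection("variables", scope="LSTM")
by_comprehension = [n for n in all_vars if n.startswith("LSTM")]
print(by_scope == by_comprehension)  # True
```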

EDIT: You can also use tf.GraphKeys.TRAINABLE_VARIABLES to get only the trainable variables. But since a vanilla BasicLSTMCell does not create any non-trainable variables, both are functionally equivalent here. For a complete list of the default graph collections, see the documentation of tf.GraphKeys.
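The equivalence claim can be sketched in the same toy spirit (plain Python, not TensorFlow; `register`, `in_scope`, and the variable names are hypothetical). Every variable is recorded in an "all" list, and trainable ones are additionally recorded in a "trainable" list, so the two scope-filtered lists differ only when the scope contains a non-trainable variable.

```python
all_vars, trainable_vars = [], []


def register(name, trainable=True):
    # Every variable goes into the "all" list; trainable ones also go into
    # the "trainable" list.
    all_vars.append(name)
    if trainable:
        trainable_vars.append(name)


def in_scope(names, scope):
    # Plain name-prefix filter, mirroring the scope filtering in the answers.
    return [n for n in names if n.startswith(scope)]


register("LSTM/BasicLSTMCell/Linear/Matrix")
register("LSTM/BasicLSTMCell/Linear/Bias")
register("global_step", trainable=False)  # non-trainable, outside the LSTM scope

# No non-trainable variables inside "LSTM", so both filters agree there.
print(in_scope(all_vars, "LSTM") == in_scope(trainable_vars, "LSTM"))  # True
print(len(all_vars), len(trainable_vars))  # 3 2
```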


