Tensorflow RNN weight matrices initialization

Anonymous (unverified), submitted 2019-12-03 01:47:02

Question:

I'm using bidirectional_rnn with GRUCell but this is a general question regarding the RNN in Tensorflow.

I couldn't find how the weight matrices (input-to-hidden, hidden-to-hidden) are initialized. Are they initialized randomly? To zeros? Are they initialized differently for each LSTM I create?

EDIT: Another motivation for this question is pre-training some LSTMs and using their weights in a subsequent model. I don't currently know how to do that without saving all the states and restoring the entire model.

Thanks.

Answer 1:

How to initialize weight matrices for RNN?

I believe people commonly use random normal initialization for RNN weight matrices. Check out the example in the TensorFlow GitHub repo. The notebook is a bit long, but it contains a simple LSTM model that uses tf.truncated_normal to initialize the weights and tf.zeros to initialize the biases (I have also tried initializing the biases with tf.ones, which seems to work too). The standard deviation is a hyperparameter you can tune yourself. Weight initialization can matter for gradient flow, although as far as I know the LSTM itself is designed to handle the vanishing-gradient problem (and gradient clipping helps with the exploding-gradient problem), so perhaps you don't need to be super careful about the std_dev setting in an LSTM. I've read papers recommending Xavier initialization (see the TF API doc for the Xavier initializer) in the convolutional-neural-network context. I don't know whether people use it for RNNs, but you could certainly try it in an RNN to see whether it helps.
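As an aside, the two schemes mentioned above (truncated normal and Xavier/Glorot) are easy to illustrate outside of TensorFlow. Below is a minimal NumPy sketch of the formulas they implement; the function names, shapes, and stddev value here are my own illustration, not from the linked notebook:

```python
import numpy as np

def truncated_normal(shape, stddev=0.1, rng=None):
    """Draw from N(0, stddev^2), re-sampling any value that falls
    outside 2 standard deviations (the tf.truncated_normal behavior)."""
    rng = rng or np.random.default_rng(0)
    w = rng.normal(0.0, stddev, size=shape)
    mask = np.abs(w) > 2 * stddev
    while mask.any():
        w[mask] = rng.normal(0.0, stddev, size=mask.sum())
        mask = np.abs(w) > 2 * stddev
    return w

def xavier_uniform(shape, rng=None):
    """Glorot/Xavier uniform: U(-limit, limit) with
    limit = sqrt(6 / (fan_in + fan_out))."""
    rng = rng or np.random.default_rng(0)
    fan_in, fan_out = shape
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=shape)

wx = truncated_normal((5000, 64), stddev=0.1)  # e.g. input-to-hidden
wt = xavier_uniform((64, 64))                  # e.g. hidden-to-hidden
print(wx.shape, wt.shape)                      # (5000, 64) (64, 64)
```

By construction every entry of wx lies within 2 * stddev of zero, and every entry of wt within sqrt(6 / (fan_in + fan_out)).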

Now to follow up with @Allen's answer and your follow up question left in the comments.

How to control initialization with variable scope?

Using the simple LSTM model in the TensorFlow GitHub Python notebook I linked to as an example: if I want to refactor the LSTM part of that notebook's code using variable-scope control, I might write something like the following...

import tensorflow as tf

def initialize_LSTMcell(vocabulary_size, num_nodes, initializer):
    '''Create the LSTM cell weights and biases and set their scopes to reuse mode.'''
    gates = ['input_gate', 'forget_gate', 'memory_cell', 'output_gate']
    with tf.variable_scope('LSTMcell') as scope:
        for gate in gates:
            with tf.variable_scope(gate) as gate_scope:
                wx = tf.get_variable("wx", [vocabulary_size, num_nodes], initializer=initializer)
                wt = tf.get_variable("wt", [num_nodes, num_nodes], initializer=initializer)
                bi = tf.get_variable("bi", [1, num_nodes], initializer=tf.constant_initializer(0.0))
                # This line can probably be omitted: setting the parent 'LSTMcell'
                # scope to reuse (next line) turns on reuse for all its child scopes.
                gate_scope.reuse_variables()
        scope.reuse_variables()

def get_scope_variables(scope_name, variable_names):
    '''Helper to fetch variables by scope name and variable name.'''
    vars = {}
    with tf.variable_scope(scope_name, reuse=True):
        for var_name in variable_names:
            vars[var_name] = tf.get_variable(var_name)
    return vars

def LSTMcell(i, o, state):
    '''Perform one LSTM cell computation.'''
    gates = ['input_gate', 'forget_gate', 'memory_cell', 'output_gate']
    var_names = ['wx', 'wt', 'bi']
    gate_comp = {}
    with tf.variable_scope('LSTMcell', reuse=True):
        for gate in gates:
            vars = get_scope_variables(gate, var_names)
            gate_comp[gate] = tf.matmul(i, vars['wx']) + tf.matmul(o, vars['wt']) + vars['bi']
    state = tf.sigmoid(gate_comp['forget_gate']) * state \
          + tf.sigmoid(gate_comp['input_gate']) * tf.tanh(gate_comp['memory_cell'])
    output = tf.sigmoid(gate_comp['output_gate']) * tf.tanh(state)
    return output, state

Usage of the refactored code would look something like the following...

initialize_LSTMcell(vocabulary_size, num_nodes, tf.truncated_normal_initializer(mean=-0.1, stddev=.01))
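For reference, the per-step gate arithmetic that LSTMcell above performs can be checked in plain NumPy. This is a sketch with made-up sizes and randomly initialized parameters (all names and dimensions here are mine, for illustration only):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(i, o, state, params):
    """One LSTM step mirroring the gate equations in LSTMcell above.
    params maps gate name -> (wx, wt, bi)."""
    comp = {g: i @ wx + o @ wt + bi for g, (wx, wt, bi) in params.items()}
    state = sigmoid(comp['forget_gate']) * state \
          + sigmoid(comp['input_gate']) * np.tanh(comp['memory_cell'])
    output = sigmoid(comp['output_gate']) * np.tanh(state)
    return output, state

rng = np.random.default_rng(0)
V, N, B = 10, 4, 2  # vocabulary size, hidden units, batch size (made up)
gates = ['input_gate', 'forget_gate', 'memory_cell', 'output_gate']
params = {g: (rng.normal(0, 0.1, (V, N)),   # wx: input-to-hidden
              rng.normal(0, 0.1, (N, N)),   # wt: hidden-to-hidden
              np.zeros((1, N)))             # bi: bias, initialized to zero
          for g in gates}

i = rng.normal(size=(B, V))   # current input
o = np.zeros((B, N))          # previous output
state = np.zeros((B, N))      # previous cell state
o, state = lstm_step(i, o, state, params)
print(o.shape, state.shape)   # (2, 4) (2, 4)
```

Since the output is a sigmoid times a tanh, every entry of o stays strictly inside (-1, 1).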