TensorFlow, best way to save state in RNNs?

慢半拍i 2020-11-27 03:39

I currently have the following code for a series of chained RNNs in TensorFlow. I am not using MultiRNN since I want to do something later on with the output of eac

3 answers
  •  -上瘾入骨i
    2020-11-27 04:20

    I am now saving the RNN states using tf.control_dependencies. Here is an example.

        # One non-trainable variable per component of the RNN state, so the state persists across runs.
        saved_states = [
            tf.get_variable('saved_state_%d' % i, shape=(BATCH_SIZE, sz), trainable=False,
                            initializer=tf.constant_initializer())
            for i, sz in enumerate(rnn.state_size)]
        W = tf.get_variable('W', shape=(2 * RNN_SIZE, RNN_SIZE),
                            initializer=tf.truncated_normal_initializer(0.0, 1 / np.sqrt(2 * RNN_SIZE)))
        b = tf.get_variable('b', shape=(RNN_SIZE,), initializer=tf.constant_initializer())

        rnn_output, states = rnn(last_output, saved_states)
        # Ops created inside this block run only after every saved state has been assigned.
        save_ops = [tf.assign(var, new) for var, new in zip(saved_states, states)]
        with tf.control_dependencies(save_ops):
            # TensorFlow 1.x takes (values, axis); releases before 1.0 took (axis, values).
            dense_input = tf.concat((last_output, rnn_output), 1)

        dense_output = tf.tanh(tf.matmul(dense_input, W) + b)
        last_output = dense_output + last_output

    I just make sure that a part of the graph I actually evaluate depends on the ops that save the state, so the state is written every time that part runs.
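The same idea can be tried in isolation. The following is only a minimal sketch, assuming TensorFlow 2.x with the `tf.compat.v1` graph API available; the hand-written `tanh` recurrence, the sizes, and the helper `run_steps` are hypothetical stand-ins for the RNN cell in the answer, not part of the original code.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API

BATCH_SIZE, RNN_SIZE = 2, 3  # hypothetical sizes for the demo

graph = tf.Graph()
with graph.as_default():
    inputs = tf1.placeholder(tf.float32, shape=(BATCH_SIZE, RNN_SIZE))

    # Non-trainable variable that carries the state from one run to the next.
    saved_state = tf1.get_variable(
        'saved_state', shape=(BATCH_SIZE, RNN_SIZE), trainable=False,
        initializer=tf1.zeros_initializer())

    # Hand-written recurrence standing in for an RNN cell (an assumption).
    new_state = tf.tanh(inputs + 0.5 * saved_state)

    # Ops created inside this block run only after the assign has run,
    # so fetching `output` also writes the new state back to the variable.
    with tf1.control_dependencies([tf1.assign(saved_state, new_state)]):
        output = tf.identity(new_state)

    init = tf1.global_variables_initializer()


def run_steps(batches):
    """Feed each batch in turn; return the per-step outputs and the final saved state."""
    with tf1.Session(graph=graph) as sess:
        sess.run(init)
        outputs = [sess.run(output, feed_dict={inputs: x}) for x in batches]
        return outputs, sess.run(saved_state)
```

Feeding the same batch twice to `run_steps` should give two different outputs, because the second run starts from the state saved by the first. Fetching only `output` is enough: the assign is never requested explicitly, it runs because `output` was created under the control dependency.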
