What is the equivalent of rnn() in TensorFlow r1.0?

Submitted anonymously (unverified) on 2019-12-03 01:27:01

Question:

I used to create the RNN network in v0.8, using:

```python
from tensorflow.python.ops import rnn

# Define an LSTM cell with TensorFlow
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)

# Get LSTM cell output
outputs, states = rnn.rnn(cell=lstm_cell, inputs=x, dtype=tf.float32)
```

rnn.rnn() is not available anymore, and it seems it has been moved to tf.contrib. What is the exact code to create an RNN network out of a BasicLSTMCell?

Or, in the case that I have a stacked LSTM:

```python
lstm_cell = tf.contrib.rnn.BasicLSTMCell(hidden_size, forget_bias=0.0)
stacked_lstm = tf.contrib.rnn.MultiRNNCell([lstm_cell] * num_layers)
outputs, new_state = tf.nn.rnn(stacked_lstm, inputs, initial_state=_initial_state)
```

So, what is the replacement for tf.nn.rnn() in TensorFlow r1.0?

Answer 1:

Updated Response

Nov 2017, per lum's comment below: Note that since version 1.2, static_rnn is back in the "main" namespace, and you should now use tf.nn.static_rnn.


Original Response

I also have similar (language-modeling) code using LSTMs that I attempted to update to TF 1.0 using tf.nn.dynamic_rnn, but I ran into problems, specifically:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimension must be 2 but is 3 for 'Model/LSTM/transpose' (op: 'Transpose') with input shapes: [20,1100], [3].

What worked for me was replacing tf.nn.rnn with tf.contrib.rnn.static_rnn; note that the signatures are identical.

My working code diff:

```diff
- outputs, final_rnn_state = tf.nn.rnn(cell, input_cnn2, initial_state=initial_rnn_state, dtype=tf.float32)
+ outputs, final_rnn_state = tf.contrib.rnn.static_rnn(cell, input_cnn2, initial_state=initial_rnn_state, dtype=tf.float32)
```
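For context on the transpose/dimension error: static_rnn (like the old tf.nn.rnn) consumes a Python list of 2-D [batch, input_size] tensors, one per timestep, whereas dynamic_rnn consumes a single 3-D tensor. That list-of-2-D contract can be sketched in plain NumPy with a minimal tanh RNN cell; the function and variable names here are illustrative, not TensorFlow API:

```python
import numpy as np

def simple_rnn_static(inputs, W_x, W_h, b, h0):
    """Unroll a tanh RNN over `inputs`, a Python list of
    [batch, input_size] arrays -- the static_rnn-style input contract."""
    h = h0
    outputs = []
    for x_t in inputs:                 # one 2-D array per timestep
        h = np.tanh(x_t @ W_x + h @ W_h + b)
        outputs.append(h)
    return outputs, h                  # list of [batch, hidden], final state

batch, time_steps, input_size, hidden = 4, 5, 3, 6
rng = np.random.default_rng(0)
xs = [rng.standard_normal((batch, input_size)) for _ in range(time_steps)]
W_x = rng.standard_normal((input_size, hidden))
W_h = rng.standard_normal((hidden, hidden))
b = np.zeros(hidden)
h0 = np.zeros((batch, hidden))

outputs, final_state = simple_rnn_static(xs, W_x, W_h, b, h0)
print(len(outputs), outputs[0].shape)   # 5 (4, 6)
```

Passing a 3-D array where a list of 2-D arrays is expected (or vice versa) is exactly the kind of mismatch that surfaces as a "Dimension must be 2 but is 3" transpose error.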


Answer 2:

You should use tf.nn.dynamic_rnn.

FYI: What is the upside of using tf.nn.rnn instead of tf.nn.dynamic_rnn in TensorFlow?
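The key practical difference is the input layout: dynamic_rnn takes a single 3-D [batch, time, input_size] tensor (with the default time_major=False) and returns a 3-D [batch, time, hidden] outputs tensor, rather than a per-timestep list. A hedged NumPy sketch of that contract, using an illustrative tanh cell (not TensorFlow code):

```python
import numpy as np

def simple_rnn_dynamic(x, W_x, W_h, b, h0):
    """Run a tanh RNN over `x`, a single [batch, time, input_size]
    array -- the dynamic_rnn-style input contract (time_major=False)."""
    batch, time_steps, _ = x.shape
    h = h0
    outputs = np.zeros((batch, time_steps, h0.shape[1]))
    for t in range(time_steps):        # slice one timestep out of the 3-D array
        h = np.tanh(x[:, t, :] @ W_x + h @ W_h + b)
        outputs[:, t, :] = h
    return outputs, h                  # [batch, time, hidden], final state

batch, time_steps, input_size, hidden = 4, 5, 3, 6
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, time_steps, input_size))
W_x = rng.standard_normal((input_size, hidden))
W_h = rng.standard_normal((hidden, hidden))
b = np.zeros(hidden)
h0 = np.zeros((batch, hidden))

outputs, final_state = simple_rnn_dynamic(x, W_x, W_h, b, h0)
print(outputs.shape)   # (4, 5, 6)
```

Because dynamic_rnn builds the time loop at run time, it also handles variable sequence lengths per batch (via its sequence_length argument), which static unrolling cannot.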


