Simple example of a CudnnGRU-based RNN implementation in TensorFlow

Posted by 随声附和 on 2019-11-30 14:33:18

CudnnGRU is not an RNNCell instance. It's more akin to dynamic_rnn.

The tensor manipulations below are equivalent, where input_tensor is a time-major tensor, i.e. of shape [max_sequence_length, batch_size, embedding_size]. CudnnGRU expects time-major input (as opposed to the more common batch-major format, i.e. of shape [batch_size, max_sequence_length, embedding_size]), and it's good practice to use time-major tensors with RNN ops anyway, since they're somewhat faster.
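To make the two layouts concrete, here is a small plain-Python sketch, with nested lists standing in for tensors. The helper name to_time_major is made up for illustration; in TensorFlow the same reordering would normally be done with tf.transpose(x, [1, 0, 2]).

```python
def to_time_major(batch_major):
    """Swap the batch and time axes of a nested list shaped [batch, time, features]."""
    # zip(*sequences) groups the t-th step of every sequence together.
    return [list(step) for step in zip(*batch_major)]


# batch_size=2, max_sequence_length=3, embedding_size=1
batch_major = [
    [[1.0], [2.0], [3.0]],  # sequence 0
    [[4.0], [5.0], [6.0]],  # sequence 1
]

# After the swap, time_major[t][b] is step t of sequence b.
time_major = to_time_major(batch_major)
```

The data is unchanged; only the order of the first two axes differs, which is why the same model can be fed either layout as long as time_major is set to match.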

CudnnGRU:

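import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

rnn = tf.contrib.cudnn_rnn.CudnnGRU(
  num_rnn_layers, hidden_size, direction='bidirectional')

# Calling the layer returns a tuple: (outputs, output_states).
rnn_output = rnn(input_tensor)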

CudnnCompatibleGRUCell:

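rnn_output = input_tensor
# Assumes `inputs` holds the integer token ids, with 0 used for padding
# (i.e. not the embedded `input_tensor`).
sequence_length = tf.reduce_sum(
  tf.sign(inputs),
  axis=0)  # 1 if `inputs` is batch-major.

for i in range(num_rnn_layers):
  # Give each layer its own variable scope so the weights don't collide.
  with tf.variable_scope('layer_%d' % i):
    fw_cell = tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell(hidden_size)
    bw_cell = tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell(hidden_size)
    outputs, _ = tf.nn.bidirectional_dynamic_rnn(
      fw_cell, bw_cell, rnn_output, sequence_length=sequence_length,
      dtype=tf.float32, time_major=True)  # Set `time_major` accordingly
    # Concatenate forward and backward outputs to feed the next layer.
    rnn_output = tf.concat(outputs, axis=2)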
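The sequence_length computation above relies on a sign-and-sum trick. A rough plain-Python sketch of the same idea follows, under the assumption that the token ids are non-negative integers with 0 reserved for padding; the helper names sign and sequence_lengths are illustrative only and not part of TensorFlow.

```python
def sign(x):
    """Return -1, 0 or 1, mirroring what tf.sign does for a scalar."""
    return (x > 0) - (x < 0)


def sequence_lengths(token_ids_time_major):
    """Count non-padding steps per sequence; input is shaped [time, batch]."""
    batch_size = len(token_ids_time_major[0])
    return [
        sum(sign(step[b]) for step in token_ids_time_major)
        for b in range(batch_size)
    ]


# max_sequence_length=4, batch_size=2; 0 marks padding (an assumption).
ids = [
    [5, 7],
    [3, 2],
    [9, 0],
    [0, 0],
]
lengths = sequence_lengths(ids)
```

Every real token contributes 1 and every padding position contributes 0, so summing along the time axis gives the unpadded length of each sequence.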

Note the following:

  1. If you were using LSTMs, you would not need CudnnCompatibleLSTMCell; the standard LSTMCell works. With GRUs, however, the Cudnn implementation uses inherently different math operations and, in particular, has more weights (see the documentation).
  2. Unlike dynamic_rnn, CudnnGRU doesn't let you specify sequence lengths. It is still over an order of magnitude faster, but you have to be careful about how you extract your outputs (e.g. if you want the final hidden state of each sequence in a padded, variable-length batch, you will need each sequence's length).
  3. Both CudnnGRU and bidirectional_dynamic_rnn return a tuple of (outputs, final states), and in the bidirectional_dynamic_rnn case each element is itself a (forward, backward) pair. Refer to the documentation, or just print the result, to see which parts of the output you need.
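As a sketch of the output extraction mentioned in point 2: when the RNN is run without sequence lengths, the last time step of a padded sequence has been computed over padding, so the useful forward output is the one at index length - 1. The plain-Python example below uses nested lists in place of tensors, and the helper name last_relevant_outputs is hypothetical. It covers only the forward direction; in a TensorFlow graph a similar lookup could be expressed with tf.gather_nd.

```python
def last_relevant_outputs(outputs_time_major, lengths):
    """Pick, for each sequence b, the output at step lengths[b] - 1.

    outputs_time_major is shaped [time, batch, hidden];
    the result is shaped [batch, hidden].
    """
    return [
        outputs_time_major[length - 1][b]
        for b, length in enumerate(lengths)
    ]


# max_sequence_length=3, batch_size=2, hidden_size=2 (toy values)
outputs = [
    [[0.1, 0.2], [1.1, 1.2]],  # t = 0
    [[0.3, 0.4], [1.3, 1.4]],  # t = 1
    [[0.5, 0.6], [1.5, 1.6]],  # t = 2 (padding for sequence 1)
]
lengths = [3, 2]
final = last_relevant_outputs(outputs, lengths)
```

Sequence 0 uses all three steps, while sequence 1 stops at t = 1, so its padded step at t = 2 is ignored.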