Predicting the next word using the LSTM PTB model TensorFlow example


I am trying to use the TensorFlow LSTM model to make next word predictions.

As described in this related question (which has no accepted answer), the example contains

2 Answers
  • 2020-12-09 21:05

    I am implementing a seq2seq model too.

    So let me try to explain based on my understanding:

    The outputs of your LSTM model form a list (of length num_steps) of 2D tensors, each of shape [batch_size, size].

    The code line:

    output = tf.reshape(tf.concat(1, outputs), [-1, size])

    will produce a new output which is a 2D tensor of shape [batch_size * num_steps, size].

    In your case, batch_size = 1 and num_steps = 20, so the output shape is [20, size].
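
    Here is a minimal NumPy sketch of that reshaping, just to check the shapes (batch_size, num_steps and size are hypothetical values matching this example):

    import numpy as np

    batch_size, num_steps, size = 1, 20, 200   # hypothetical hidden size
    # stand-in for the list of per-step LSTM outputs
    outputs = [np.zeros((batch_size, size)) for _ in range(num_steps)]
    # concat along axis 1 -> [batch_size, num_steps * size], then reshape -> [batch_size * num_steps, size]
    output = np.reshape(np.concatenate(outputs, axis=1), (-1, size))
    print(output.shape)   # (20, 200)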

    Code line:

    logits = tf.nn.xw_plus_b(output, tf.get_variable("softmax_w", [size, vocab_size]), tf.get_variable("softmax_b", [vocab_size]))

    multiplies output (shape [batch_size * num_steps, size]) by softmax_w (shape [size, vocab_size]) and adds softmax_b, producing logits of shape [batch_size * num_steps, vocab_size].
    In your case, logits has shape [20, vocab_size], and the probs tensor (the softmax of logits) has the same shape, [20, vocab_size].
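
    Continuing the NumPy sketch above (softmax_w and softmax_b are random stand-ins for the trained variables):

    vocab_size = 10000                                   # hypothetical vocabulary size
    softmax_w = np.random.randn(size, vocab_size)
    softmax_b = np.random.randn(vocab_size)
    logits = output.dot(softmax_w) + softmax_b           # shape (20, vocab_size)
    # probs is just the row-wise softmax of logits, so it keeps the same shape
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = e / e.sum(axis=1, keepdims=True)             # shape (20, vocab_size)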

    Code line:

    chosen_word = np.argmax(probs, 1)

    will output a chosen_word array of shape [20], where each value is the index of the predicted next word for the corresponding position in the input.
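
    For example, with a hypothetical id_to_word list that maps word ids back to strings:

    chosen_word = np.argmax(probs, 1)                        # shape (20,): one predicted word id per step
    predicted_words = [id_to_word[i] for i in chosen_word]   # id_to_word is a hypothetical reverse vocabulary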

    Code line:

    loss = seq2seq.sequence_loss_by_example([logits], [tf.reshape(self._targets, [-1])], [tf.ones([batch_size * num_steps])])

    computes the softmax cross-entropy loss for the batch_size sequences, one value per target position (hence the batch_size * num_steps weights of 1.0).
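
    In NumPy terms, what that line computes here is roughly the following (targets is a random stand-in for self._targets; dividing by batch_size is how the PTB model then turns loss into its cost):

    targets = np.random.randint(0, vocab_size, batch_size * num_steps)   # stand-in target word ids
    # per-position cross entropy: -log(probability assigned to the true next word)
    loss = -np.log(probs[np.arange(batch_size * num_steps), targets])    # shape (20,)
    cost = loss.sum() / batch_size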

  • The output tensor contains the concatenation of the LSTM cell outputs for each timestep (see its definition here). Therefore you can find the prediction for the next word by taking chosen_word[-1] (or chosen_word[sequence_length - 1] if the sequence has been padded to match the unrolled LSTM).
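
    For example, assuming probs and num_steps as in the other answer, and that sequence_length is the number of real (non-padded) words you fed in:

    chosen_word = np.argmax(probs, axis=1)
    if sequence_length < num_steps:                      # input was padded up to the unrolled length
        next_word_id = chosen_word[sequence_length - 1]
    else:
        next_word_id = chosen_word[-1]                   # prediction following the last input word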

    The tf.nn.sparse_softmax_cross_entropy_with_logits() op is documented in the public API under a different name. For technical reasons, it calls a generated wrapper function that does not appear in the GitHub repository. The implementation of the op is in C++, here.
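
    If you want to call the op directly, a minimal sketch in the current keyword-argument form (with targets and logits shaped as in the other answer) looks like this:

    import tensorflow as tf

    # targets: int word ids, reshaped to [batch_size * num_steps]
    # logits: unnormalized scores of shape [batch_size * num_steps, vocab_size]
    per_position_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.reshape(targets, [-1]), logits=logits)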
