What does the “source hidden state” refer to in the Attention Mechanism?

Submitted by 爱⌒轻易说出口 on 2020-01-24 10:29:06

Question


The attention weights are computed as:

    alpha_ts = exp(score(h_t, h̄_s)) / sum_{s'} exp(score(h_t, h̄_{s'}))

I want to know what the h̄_s refers to.

In the tensorflow code, the encoder RNN returns a tuple:

encoder_outputs, encoder_state = tf.nn.dynamic_rnn(...)
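The difference between the two return values can be sketched with a toy RNN in plain Python (a hypothetical scalar-state model, not the TensorFlow API): `outputs` collects the hidden state at every time step, while `state` is only the hidden state after the last step.

```python
import math

# Toy single-unit RNN (hypothetical, for illustration only). Like dynamic_rnn,
# it returns (outputs, state): `outputs` stacks the hidden state at EVERY
# source position, `state` is the hidden state after the LAST step only.
def toy_rnn(inputs, w_x=0.5, w_h=0.3):
    h = 0.0
    outputs = []
    for x in inputs:
        h = math.tanh(w_x * x + w_h * h)  # one recurrent step
        outputs.append(h)                 # collected at every step
    return outputs, h                     # (encoder_outputs, encoder_state)

outputs, state = toy_rnn([1.0, 2.0, 3.0])
# one hidden state per source position; the state equals the last output
```
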

As I understand it, h̄_s should be the encoder_state, but github/nmt gives a different answer:

# attention_states: [batch_size, max_time, num_units]
attention_states = tf.transpose(encoder_outputs, [1, 0, 2])

# Create an attention mechanism
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units, attention_states,
    memory_sequence_length=source_sequence_length)

Did I misunderstand the code? Or does h̄_s actually mean the encoder_outputs?


Answer 1:


The formula is probably from this post, so I'll use an NN picture from the same post:

Here, the h̄_s are all the blue hidden states from the encoder (its last layer), and h_t is the current red hidden state from the decoder (also its last layer). On the picture t = 0, and you can see which blocks are wired to the attention weights with dotted arrows. The score function is usually one of these (dot, general, and concat, from Luong et al.):

    score(h_t, h̄_s) = h_tᵀ h̄_s                  (dot)
    score(h_t, h̄_s) = h_tᵀ W_a h̄_s              (general)
    score(h_t, h̄_s) = v_aᵀ tanh(W_a [h_t; h̄_s]) (concat)

TensorFlow's attention mechanism matches this picture. In theory, a cell's output is in most cases its hidden state (one exception is the LSTM cell, where the output is the short-term part of the state, and even in this case the output suits the attention mechanism better). In practice, TensorFlow's encoder_state differs from encoder_outputs when the input is padded with zeros: the state is propagated past the padding from the previous cell state, while the output is zero. Obviously you don't want to attend to trailing zeros, so it makes sense not to include h̄_s for these padded cells; that is what memory_sequence_length is for.
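The effect of memory_sequence_length can be sketched in plain Python (a simplified model, not TensorFlow internals): positions beyond the real sequence length are masked to -inf before the softmax, so the padded slots receive exactly zero attention weight.

```python
import math

# Sketch of length masking in attention (hypothetical helper, not the TF API):
# scores past seq_len are set to -inf, so exp() maps them to exactly 0.0
# and the decoder can never attend to trailing zero padding.
def attention_weights(scores, seq_len):
    masked = [s if i < seq_len else float("-inf")
              for i, s in enumerate(scores)]
    exps = [math.exp(s) for s in masked]   # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

# 4 memory slots, but only the first 2 are real source tokens
weights = attention_weights([2.0, 1.0, 0.5, 0.0], seq_len=2)
```
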

So encoder_outputs are exactly the arrows that go upward from the blue blocks. Later in the code, attention_mechanism is connected to each decoder_cell, so that its output goes through the context vector to the yellow block in the picture.

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism,
    attention_layer_size=num_units)
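What the wrapper does at each decoder step can be sketched as a toy vector computation in plain Python (a hypothetical function, assuming the dot-product score): the decoder state h_t is scored against every encoder output h̄_s, the scores are softmaxed into attention weights, and the context vector is the weighted sum of the encoder outputs.

```python
import math

# One attention step (hypothetical sketch, dot-product score):
# score each encoder output against h_t, softmax into weights,
# then form the context vector as a weighted sum of encoder outputs.
def luong_step(h_t, encoder_outputs):
    scores = [sum(a * b for a, b in zip(h_t, h_s))    # dot-product score
              for h_s in encoder_outputs]
    exps = [math.exp(s) for s in scores]
    weights = [e / sum(exps) for e in exps]           # attention weights
    context = [sum(w * h_s[d] for w, h_s in zip(weights, encoder_outputs))
               for d in range(len(h_t))]              # weighted sum
    return context, weights

# h_t is most similar to the first encoder output, so it gets more weight
ctx, w = luong_step([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]])
```
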


Source: https://stackoverflow.com/questions/48394009/what-does-the-source-hidden-state-refer-to-in-the-attention-mechanism
