What's the difference between “hidden” and “output” in PyTorch LSTM?

孤街浪徒 2020-12-12 09:47

I'm having trouble understanding the documentation for PyTorch's LSTM module (and also RNN and GRU, which are similar). Regarding the outputs, it says:

4 answers
  • 2020-12-12 10:11

    I made a diagram. The names follow the PyTorch docs, although I renamed num_layers to w.

    output comprises all the hidden states in the last layer ("last" depth-wise, not time-wise). (h_n, c_n) comprises the hidden and cell states of every layer after the last timestep, t = n, so you could potentially feed them into another LSTM as its initial state.

    The batch dimension is not shown in the diagram.
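    A minimal sketch of the shapes described above, assuming a two-layer, unidirectional nn.LSTM with the default batch_first=False (all sizes are arbitrary). It also feeds (h_n, c_n) from the first chunk of a sequence back in as the initial state for the second chunk, which should reproduce the output of running the whole sequence at once.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, batch, input_size, hidden_size, num_layers = 7, 3, 4, 5, 2
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)  # batch_first=False by default
x = torch.randn(seq_len, batch, input_size)

output, (h_n, c_n) = lstm(x)

# output: the last layer's hidden state at every time step
print(output.shape)          # torch.Size([7, 3, 5])
# h_n, c_n: every layer's hidden / cell state at the last time step
print(h_n.shape, c_n.shape)  # torch.Size([2, 3, 5]) torch.Size([2, 3, 5])

# The last layer's final hidden state appears in both tensors
print(torch.allclose(output[-1], h_n[-1]))  # True

# Feed (h_n, c_n) of the first chunk in as the initial state of the second chunk
out_a, state_a = lstm(x[:4])
out_b, state_b = lstm(x[4:], state_a)
print(torch.allclose(torch.cat([out_a, out_b]), output, atol=1e-5))  # True
```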

  • 2020-12-12 10:13

    I just verified some of this using code, and it's indeed correct that for a single-layer (depth 1) LSTM, h_n is the same as the last value of output. (For an LSTM with more than one layer, h_n holds the final hidden state of every layer, so only its last entry, h_n[-1], matches output[-1], as explained above by @nnnmmm.)

    So, basically, the "output" returned by the LSTM is not o_t as defined in the documentation (the output gate); rather, it is h_t.
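    One way to see this is to recompute the first time step by hand from the layer's weights. This is a minimal sketch using the same setup as the code in this answer, and it assumes PyTorch's documented gate order (i, f, g, o) inside weight_ih_l0 / weight_hh_l0 and a zero initial state. output[0] should match h_t = o_t * tanh(c_t), not the gate o_t itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.LSTM(input_size=1, hidden_size=50, num_layers=1)
x = torch.rand(50, 1, 1)  # (seq_len, batch, input_size)
output, (hn, cn) = model(x)

with torch.no_grad():
    x0 = x[0]                # first time step, shape (batch, input_size)
    h0 = torch.zeros(1, 50)  # default initial hidden state
    c0 = torch.zeros(1, 50)  # default initial cell state
    # Pre-activations of all four gates, stacked in the order (i, f, g, o)
    gates = (x0 @ model.weight_ih_l0.T + model.bias_ih_l0
             + h0 @ model.weight_hh_l0.T + model.bias_hh_l0)
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)
    c1 = f * c0 + i * g      # new cell state c_t
    h1 = o * torch.tanh(c1)  # new hidden state h_t

print(torch.allclose(output[0], h1, atol=1e-5))  # True: output holds h_t
print(torch.allclose(output[0], o, atol=1e-5))   # False: output is not the gate o_t
```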

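    import torch
    import torch.nn as nn
    
    torch.manual_seed(0)
    # Single-layer LSTM: 1 input feature, 50 hidden units
    model = nn.LSTM(input_size=1, hidden_size=50, num_layers=1)
    # Input of shape (seq_len=50, batch=1, input_size=1); batch_first is False by default
    x = torch.rand(50, 1, 1)
    output, (hn, cn) = model(x)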
    

    Now one can check that output[-1] and hn contain the same values:

    tensor([[ 0.1140, -0.0600, -0.0540,  0.1492, -0.0339, -0.0150, -0.0486,  0.0188,
              0.0504,  0.0595, -0.0176, -0.0035,  0.0384, -0.0274,  0.1076,  0.0843,
             -0.0443,  0.0218, -0.0093,  0.0002,  0.1335,  0.0926,  0.0101, -0.1300,
             -0.1141,  0.0072, -0.0142,  0.0018,  0.0071,  0.0247,  0.0262,  0.0109,
              0.0374,  0.0366,  0.0017,  0.0466,  0.0063,  0.0295,  0.0536,  0.0339,
              0.0528, -0.0305,  0.0243, -0.0324,  0.0045, -0.1108, -0.0041, -0.1043,
             -0.0141, -0.1222]], grad_fn=<SelectBackward>)
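    Rather than comparing printed values by eye, the check can be done with torch.allclose. A minimal sketch, reusing the same kind of input and adding a two-layer model for comparison (the loop is illustrative): with num_layers=2, hn holds one final hidden state per layer, and only the last layer's entry, hn[-1], matches output[-1].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(50, 1, 1)  # (seq_len, batch, input_size)

results = {}
for num_layers in (1, 2):
    model = nn.LSTM(input_size=1, hidden_size=50, num_layers=num_layers)
    output, (hn, cn) = model(x)
    # hn has one entry per layer; output only covers the last layer
    matches_last_layer = torch.allclose(output[-1], hn[-1])
    matches_first_layer = torch.allclose(output[-1], hn[0])
    results[num_layers] = (tuple(hn.shape), matches_last_layer, matches_first_layer)
    print(num_layers, *results[num_layers])
# 1 (1, 1, 50) True True
# 2 (2, 1, 50) True False
```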
    
  • 2020-12-12 10:17

    It really depends on the model you use and how you interpret it. The output may be:

    • a single LSTM cell hidden state
    • several LSTM cell hidden states
    • all of the hidden-state outputs

    The output is almost never interpreted directly. If the input is encoded, there should be a softmax layer (usually preceded by a linear layer) to decode the results.
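    A minimal sketch of two common readings of output, assuming a hypothetical linear "head" layer and arbitrary sizes (none of this comes from the answer): use only the last time step, output[-1], for one prediction per sequence, or keep every time step for one prediction per step. In both cases a linear layer plus softmax decodes the hidden states into probabilities.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size, num_classes = 6, 2, 3, 8, 4

lstm = nn.LSTM(input_size, hidden_size)
head = nn.Linear(hidden_size, num_classes)  # hypothetical decoder layer
x = torch.randn(seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x)

# One prediction per sequence: only the last time step's hidden state
seq_probs = torch.softmax(head(output[-1]), dim=-1)  # (batch, num_classes)

# One prediction per time step: every hidden state in output
step_probs = torch.softmax(head(output), dim=-1)     # (seq_len, batch, num_classes)

print(seq_probs.shape, step_probs.shape)  # torch.Size([2, 4]) torch.Size([6, 2, 4])
```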

    Note: In language modeling, hidden states are used to define the probability of the next word, $p(w_{t+1} \mid w_1, \ldots, w_t) = \mathrm{softmax}(W h_t + b)$.
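    The language-modeling formula can be sketched as a tiny next-word model. This assumes a hypothetical vocabulary of 10 token ids, an nn.Embedding for the inputs, and an nn.Linear layer standing in for W and b; the sizes and token ids are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, embed_dim, hidden_size = 10, 6, 16

embed = nn.Embedding(vocab_size, embed_dim)
lstm = nn.LSTM(embed_dim, hidden_size)
proj = nn.Linear(hidden_size, vocab_size)  # holds W (proj.weight) and b (proj.bias)

tokens = torch.tensor([[1], [4], [7]])  # w_1..w_3, shape (seq_len, batch=1)
output, _ = lstm(embed(tokens))         # output[t] is h_t for each position t

# p(w_{t+1} | w_1, ..., w_t) = softmax(W h_t + b), for every t at once
next_word_probs = torch.softmax(proj(output), dim=-1)  # (seq_len, batch, vocab_size)

# Same distribution for the last position, with W and b written out explicitly
h_t = output[-1]
manual = torch.softmax(h_t @ proj.weight.T + proj.bias, dim=-1)
print(torch.allclose(next_word_probs[-1], manual, atol=1e-6))  # True
```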

  • 2020-12-12 10:20

    The output state is the tensor of all the hidden states (from the last layer) at each time step of the RNN (LSTM), and the hidden state returned by the RNN (LSTM) is the hidden state from the last time step of the input sequence. You could check this by collecting the hidden state from each step and comparing the collection to the output state (provided you are not using pack_padded_sequence).
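    A minimal sketch of that check, assuming ordinary (not packed) input: it feeds the sequence to the same LSTM one time step at a time, carrying the (h, c) state forward, collects the last layer's hidden state at each step, and compares the stack with output. Two layers are used (an arbitrary choice) so that the final state also shows one entry per layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 5, 2, 3, 4
lstm = nn.LSTM(input_size, hidden_size, num_layers=2)
x = torch.randn(seq_len, batch, input_size)

output, (h_n, c_n) = lstm(x)

# Run the same LSTM one step at a time, carrying (h, c) forward
state = None
collected = []
for t in range(seq_len):
    out_t, state = lstm(x[t:t + 1], state)  # out_t: (1, batch, hidden_size)
    collected.append(out_t[0])               # last layer's hidden state at step t
step_by_step = torch.stack(collected)        # (seq_len, batch, hidden_size)

print(torch.allclose(step_by_step, output, atol=1e-5))  # True: output = hidden state at every step
print(torch.allclose(state[0], h_n, atol=1e-5))         # True: h_n = each layer's final hidden state
```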
