PyTorch nn.LSTM() parameters explained
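Here is a minimal sketch that checks how the shapes fit together. It assumes a recent PyTorch release and the default batch_first=False. The sizes and bidirectional=True are illustrative choices not taken from the original; they are picked so that num_directions = 2 appears in the shapes. The last two checks rely on the documented layout of hn for bidirectional LSTMs, where index -2 is the last layer's forward direction and index -1 is its backward direction.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed for this sketch, not from the original post).
seq_len, batch, input_size = 5, 3, 10
hidden_size, num_layers = 20, 2
bidirectional = True
num_directions = 2 if bidirectional else 1

# With batch_first=False (the default), input is (seq_len, batch, input_size).
lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)

x = torch.randn(seq_len, batch, input_size)
# Initial states: (num_layers * num_directions, batch, hidden_size).
h0 = torch.zeros(num_layers * num_directions, batch, hidden_size)
c0 = torch.zeros(num_layers * num_directions, batch, hidden_size)

output, (hn, cn) = lstm(x, (h0, c0))

print(output.shape)  # torch.Size([5, 3, 40])  (seq_len, batch, hidden_size * num_directions)
print(hn.shape)      # torch.Size([4, 3, 20])  (num_layers * num_directions, batch, hidden_size)
print(cn.shape)      # torch.Size([4, 3, 20])

# The forward half of output at the last time step is the last layer's final forward state.
print(torch.allclose(output[-1, :, :hidden_size], hn[-2]))  # True
# The backward direction finishes at t = 0, so its half of output[0] is its final state.
print(torch.allclose(output[0, :, hidden_size:], hn[-1]))   # True
```

Because of this layout, output[-1] for a bidirectional LSTM does not equal the concatenated final states. To read both directions of the last layer, take hn[-2] and hn[-1], or reshape with hn.view(num_layers, num_directions, batch, hidden_size).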
Input data format:
input: (seq_len, batch, input_size)
h0: (num_layers * num_directions, batch, hidden_size)
c0: (num_layers * num_directions, batch, hidden_size)

Output data format:
output: (seq_len, batch, hidden_size * num_directions)
hn: (num_layers * num_directions, batch, hidden_size)
cn: (num_layers * num_directions, batch, hidden_size)

These shapes assume the default batch_first=False. num_directions is 2 for a bidirectional LSTM and 1 otherwise.

import torch
import torch.nn as nn
from torch.autograd import Variable  # deprecated since PyTorch 0.4; plain tensors can be passed directly

# Build the network model: input feature size input_size, output feature size hidden_size, number of layers num_layers
inputs = torch.randn(5, 3, 10)  # -> (seq_len, batch_size, input_size)
rnn = nn.LSTM(10, 20, 2)  # -> (input_size, hidden_size, num_layers)
h0 = torch.randn(2, 3, 20)  # -> (num