Recurrent neural network (RNN) - PyTorch version

Submitted anonymously (unverified) on 2019-12-02 23:56:01
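The complete example below trains a two-layer LSTM classifier on MNIST with PyTorch. Each 28x28 image is treated as a sequence of 28 rows of 28 pixels: the LSTM consumes one row per time step, and the hidden state at the last time step is passed to a fully connected layer that produces the 10 class scores (a many-to-one setup).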
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Device configuration: use the GPU if available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
sequence_length = 28   # each 28x28 image is read as a sequence of 28 rows
input_size = 28        # each row has 28 pixel values
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01

# MNIST dataset
# ToTensor() converts a PIL Image or ndarray to a tensor and scales it to [0, 1]
# (i.e. divides the pixel values by 255)
train_dataset = torchvision.datasets.MNIST(root='./data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data/',
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loaders: the training set is loaded in batches and shuffled each epoch;
# the test set is loaded in batches without shuffling
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)


# Recurrent neural network (many-to-one)
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()  # inherit nn.Module's __init__
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # An LSTM works far better here than a plain nn.RNN(), which hardly learns
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out


model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
print(model)
# RNN(
#   (lstm): LSTM(28, 128, num_layers=2, batch_first=True)
#   (fc): Linear(in_features=128, out_features=10, bias=True)
# )

# Loss function and optimizer (Adam over the model's parameters with the learning rate above)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Reshape each batch from (batch, 1, 28, 28) to (batch, 28, 28): a sequence of 28 rows
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        # Forward pass and loss
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()        # backpropagate
        optimizer.step()       # apply the updates to the model's parameters

        # Print progress every 100 steps
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
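To reuse the trained weights later, the saved state dict can be loaded back into a freshly constructed model. A minimal sketch, assuming the 'model.ckpt' file written above and that the RNN class, hyperparameters, and test_dataset defined above are still in scope:

# Rebuild the architecture and load the trained weights
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
model.load_state_dict(torch.load('model.ckpt', map_location=device))
model.eval()  # switch to evaluation mode

with torch.no_grad():
    sample, _ = test_dataset[0]  # one MNIST image, shape (1, 28, 28)
    sample = sample.reshape(-1, sequence_length, input_size).to(device)
    predicted = model(sample).argmax(dim=1)
    print('Predicted digit:', predicted.item())

Saving state_dict() rather than the whole model keeps the checkpoint portable: only the parameter tensors are serialized, so the model class definition must be available (and unchanged) when loading.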

  
