Pytorch实现手写数字mnist识别

匿名 (未验证) 提交于 2019-12-03 00:29:01
import torch import torchvision as tv import torchvision.transforms as transforms import torch.nn as nn import torch.optim as optim import argparse # 定义是否使用GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 定义网络结构 class LeNet(nn.Module):     def __init__(self):         super(LeNet, self).__init__()         self.conv1 = nn.Sequential(     #input_size=(1*28*28)             nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同             nn.ReLU(),      #input_size=(6*28*28)             nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)         )         self.conv2 = nn.Sequential(             nn.Conv2d(6, 16, 5),             nn.ReLU(),      #input_size=(16*10*10)             nn.MaxPool2d(2, 2)  #output_size=(16*5*5)         )         self.fc1 = nn.Sequential(             nn.Linear(16 * 5 * 5, 120),             nn.ReLU()         )         self.fc2 = nn.Sequential(             nn.Linear(120, 84),             nn.ReLU()         )         self.fc3 = nn.Linear(84, 10)      # 定义前向传播过程,输入为x     def forward(self, x):         x = self.conv1(x)         x = self.conv2(x)         # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维         x = x.view(x.size()[0], -1)         x = self.fc1(x)         x = self.fc2(x)         x = self.fc3(x)         return x #使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多 parser = argparse.ArgumentParser() parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径 parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)")  #模型加载路径 opt = parser.parse_args()  # 超参数设置 EPOCH = 8   #遍历数据集次数 BATCH_SIZE = 64      #批处理尺寸(batch_size) LR = 0.001        #学习率  # 定义数据预处理方式 transform = transforms.ToTensor()  # 定义训练数据集 trainset = tv.datasets.MNIST(     root='./data/',     train=True,     download=True,     transform=transform)  # 定义训练批处理数据 trainloader = torch.utils.data.DataLoader(     trainset,     batch_size=BATCH_SIZE,     shuffle=True,     )  # 定义测试数据集 testset = tv.datasets.MNIST(     root='./data/',     train=False,     download=True,     transform=transform)  # 定义测试批处理数据 testloader = torch.utils.data.DataLoader(     testset,     batch_size=BATCH_SIZE,     shuffle=False,     )  # 定义损失函数loss function 和优化方式(采用SGD) net = LeNet().to(device) criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,通常用于多分类问题上 optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # 训练 if __name__ == "__main__":      for epoch in range(EPOCH):         sum_loss = 0.0         # 数据读取         for i, data in enumerate(trainloader):             inputs, labels = data             inputs, labels = inputs.to(device), labels.to(device)              # 梯度清零             optimizer.zero_grad()              # forward + backward             outputs = net(inputs)             loss = criterion(outputs, labels)             loss.backward()             optimizer.step()              # 每训练100个batch打印一次平均loss             sum_loss += loss.item()             if i % 100 == 99:                 print('[%d, %d] loss: %.03f'                       % (epoch + 1, i + 1, sum_loss / 100))                 sum_loss = 0.0         # 每跑完一次epoch测试一下准确率         with torch.no_grad():             correct = 0             total = 0             for data in testloader:                 images, labels = data                 images, labels = images.to(device), labels.to(device)                 outputs = net(images)                 # 取得分最高的那个类                 _, predicted = torch.max(outputs.data, 1)                 total += labels.size(0)                 correct += (predicted == labels).sum()             print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))     #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1)) 
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!