学习pytorch的小记录

ぐ巨炮叔叔 提交于 2020-03-17 09:47:56

原教程地址

PyTorch简明教程

1.PyTorch神经网络简介

1.1 计算梯度

那里是有一点不太明白的记录一下。
在调用loss.backward()之前,我们需要清除掉tensor里之前的梯度,否则会累加进去。

net.zero_grad()     # 清掉tensor里缓存的梯度值。

print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)

loss.backward()

print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

这一段的

loss.backward()

这个语句的功能在学习的时候是不太明白的,以后懂了再来填坑

1.2 这一部分的可运行代码贴在下面

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #输入是1个通道的灰度图,输出6个通道(feature map),使用5*5的卷积核
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 第二个卷积层也是5x5,有16个通道
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 全连接层
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 32x32 -> 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv1(x)),(2, 2))
        # 14x14 -> 10x10 -> 5x5
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:] # 除了batch维度之外的其它维度。
        num_features = 1
        for s in size:
            num_features *= s
        return num_features



net = Net()
#print(net)


'''我们只需要定义forward函数,而backward函数会自动通过autograd创建。
在forward函数里可以使用任何处理Tensor的函数。
我们可以使用函数net.parameters()来得到模型所有的参数。'''
#params = list(net.parameters())
#print(len(params))
#print(params[0].size())  # conv1的weight

'''测试网络
接着我们尝试一个随机的32x32的输入来检验(sanity check)网络定义没有问题。
注意:这个网络(LeNet)期望的输入大小是32x32。如果使用MNIST数据集(28x28),我们需要缩放到32x32。'''
input = torch.randn(1, 1, 32, 32)
#out = net(input)
#print(out)
'''默认的梯度会累加,因此我们通常在backward之前清除掉之前的梯度值:'''
#net.zero_grad()
#out.backward(torch.randn(1, 10))
'''注意:torch.nn只支持mini-batches的输入。
整个torch.nn包的输入都必须第一维是batch,即使只有一个样本也要弄成batch是1的输入。'''

'''比如,nn.Conv2d的输入是一个4D的Tensor,shape是nSamples x nChannels x Height x Width。
如果你只有一个样本(nChannels x Height x Width),那么可以使用input.unsqueeze(0)来增加一个batch维。'''

output = net(input)
target = torch.arange(1, 11)  # 随便伪造的一个“真实值”
target = target.view(1, -1)  # 把它变成output的shape(1, 10)
criterion = nn.MSELoss()

loss = criterion(output, target.float())
#print(loss)


#print(loss.grad_fn)  # MSELoss
#print(loss.grad_fn.next_functions[0][0])  # Add
#print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # Expand

  # 清掉tensor里缓存的梯度值。

net.zero_grad()
print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)

loss.backward()

print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)


'''更新参数最简单的方法是使用随机梯度下降(SGD): weight=weight−learningrate∗gradient 
我们可以使用如下简单的代码来实现更新:'''
#learning_rate = 0.01
#for f in net.parameters():
#	f.data.sub_(f.grad.data * learning_rate)


'''通常我们会使用更加复杂的优化方法,比如SGD, Nesterov-SGD, Adam, RMSProp等等。
为了实现这些算法,我们可以使用torch.optim包,它的用法也非常简单:'''
import torch.optim as optim

# 创建optimizer,需要传入参数和learning rate
optimizer = optim.SGD(net.parameters(), lr=0.01)

# 清除梯度
optimizer.zero_grad()
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()    # optimizer会自动帮我们更新参数

2.训练一个分类器

2.1 如何进行数据处理

一般地,当我们处理图片、文本、音频或者视频数据的时候,我们可以使用python代码来把它转换成numpy数组。然后再把numpy数组转换成torch.xxxTensor。

  • 对于处理图像,常见的lib包括Pillow和OpenCV
  • 对于音频,常见的lib包括scipy和librosa
  • 对于文本,可以使用标准的Python库,另外比较流行的lib包括NLTK和SpaCy

对于视觉问题,PyTorch提供了一个torchvision包(需要单独安装),它对于常见数据集比如Imagenet, CIFAR10, MNIST等提供了加载的方法。并且它也提供很多数据变化的工具,包括torchvision.datasets和torch.utils.data.DataLoader。这会极大的简化我们的工作,避免重复的代码。

2.2 训练的步骤

  • 使用torchvision加载和预处理CIFAR10训练和测试数据集。
  • 定义卷积网络
  • 定义损失函数
  • 用训练数据训练模型
  • 用测试数据测试模型

2.3 数据处理

直接按教程里给的代码拼在一起是不能直接运行的,会报错。拼起来如下。

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
	[transforms.ToTensor(),
	transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='/path/to/data', train=True,
	download=True, transform=transform)
	trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
	shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='/path/to/data', train=False,
	download=True, transform=transform)
	testloader = torch.utils.data.DataLoader(testset, batch_size=4,
	shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

import matplotlib.pyplot as plt
import numpy as np

# 显示图片的函数
def imshow(img):
img = img / 2 + 0.5     #  [-1,1] -> [0,1]
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0))) # (channel, width, height) -> (width, height, channel)


# 随机选择一些图片
dataiter = iter(trainloader)
images, labels = dataiter.next()

# 显示图片
imshow(torchvision.utils.make_grid(images))
# 打印label
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

可运行代码如下

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import sys

# 显示图片的函数
def imshow(img):
    img = img / 2 + 0.5     #  [-1,1] -> [0,1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0))) # (channel, width, height) -> (width, height, channel)
def main(argv = None):
    transform = transforms.Compose(
    [transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root=r'C:\Users\lyy', train = True,download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root=r'C:\Users\lyy', train=False,download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # 随机选择一些图片
    dataiter = iter(trainloader)
    images, labels = dataiter.next()
    # 显示图片
    imshow(torchvision.utils.make_grid(images))
    plt.show()
    # 打印label
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))


if __name__=='__main__':
    sys.exit(main())

目前暂时后面的还没看完~

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!