打卡(一)

懵懂的女人 提交于 2020-02-14 17:44:15

线性回归

从零开始

  • 生成/准备数据集

  • 读取数据集

  • 初始化模型参数

  • 正态分布初始化:np.random.normal(loc(均值),scale(标准差),size(输出形状))

  • 全零初始化:np.zeros()

  • 定义模型

    • torch.mm(矩阵相乘)

    • torch.mul(点乘)

  • 定义损失函数

    • 均方损失函数

    • 交叉熵损失函数

  • 定义优化函数

  • 小批量随机梯度下降,参数-学习率*梯度/batch_size

  • 训练

  • 前向传播

  • 求损失

  • 反向传播

  • 优化参数

  • 梯度清零

pytorch简介实现

  • Dataloader使用

      import torch.utils.data as Data
    
      
    
      batch_size = 10
    
      
    
      # combine featues and labels of dataset
    
      dataset = Data.TensorDataset(features, labels)
    
      
    
      # put dataset into DataLoader
    
      data_iter = Data.DataLoader(
    
          dataset=dataset,            # torch TensorDataset format
    
          batch_size=batch_size,      # mini batch size
    
          shuffle=True,               # whether shuffle the data or not
    
          num_workers=2,              # read data in multithreading
    
      )
    
  • 定义模型

      class LinearNet(nn.Module):
    
          def __init__(self, n_feature):
    
              super(LinearNet, self).__init__()      # call father function to init 
    
              self.linear = nn.Linear(n_feature, 1)  # function prototype: `torch.nn.Linear(in_features, out_features, bias=True)`
    
      
    
          def forward(self, x):
    
              y = self.linear(x)
    
              return y
    
          
    
      net = LinearNet(num_inputs)
    
      print(net)
    
    • 三种方式添加层

        # ways to init a multilayer network
      
        # method one
      
        net = nn.Sequential(
      
            nn.Linear(num_inputs, 1)
      
            # other layers can be added here
      
            )
      
        
      
        # method two
      
        net = nn.Sequential()
      
        net.add_module('linear', nn.Linear(num_inputs, 1))
      
        # net.add_module ......
      
        
      
        # method three
      
        from collections import OrderedDict
      
        net = nn.Sequential(OrderedDict([
      
                  ('linear', nn.Linear(num_inputs, 1))
      
                  # ......
      
                ]))
      
        
      
        print(net)
      
        print(net[0])
      
  • 初始化参数

      from torch.nn import init
    
      
    
      init.normal_(net[0].weight, mean=0.0, std=0.01)
    
      init.constant_(net[0].bias, val=0.0)  # or you can use `net[0].bias.data.fill_(0)` to modify it directly
    
  • 损失函数:nn.MSEloss()均方损失函数

  • 优化函数

import torch.optim as optim



optimizer = optim.SGD(net.parameters(), lr=0.03)   # built-in random gradient descent function

print(optimizer)  # function prototype: `torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)

-训练

num_epochs = 3

for epoch in range(1, num_epochs + 1):

    for X, y in data_iter:

        output = net(X)

        l = loss(output, y.view(-1, 1))

        optimizer.zero_grad() # reset gradient, equal to net.zero_grad()

        l.backward()

        optimizer.step()

    print('epoch %d, loss: %f' % (epoch, l.item()))
  • 前向传播

  • 求损失函数值

  • 清空梯度

  • 反向传播

  • 优化

Softmax和分类模型

softmax函数

某一类的概率=某一类预测值的指数/所有预测值指数的和

交叉熵损失函数

torchvidion.datasets

root(string)– 数据集的根目录,其中存放processed/training.pt和processed/test.pt文件。

  • train(bool, 可选)– 如果设置为True,从training.pt创建数据集,否则从test.pt创建。

  • download(bool, 可选)– 如果设置为True,从互联网下载数据并放到root文件夹下。如果root目录下已经存在数据,不会再次下载。

  • transform(可被调用 , 可选)– 一种函数或变换,输入PIL图片,返回变换之后的数据。如:transforms.RandomCrop。

  • target_transform(可被调用 , 可选)– 一种函数或变换,输入目标,进行变换。

多维tensor按维度操作

X = torch.tensor([[1, 2, 3], [4, 5, 6]])

print(X.sum(dim=0, keepdim=True))  # dim为0,按照相同的列求和,并在结果中保留列特征

print(X.sum(dim=1, keepdim=True))  # dim为1,按照相同的行求和,并在结果中保留行特征

print(X.sum(dim=0, keepdim=False)) # dim为0,按照相同的列求和,不在结果中保留列特征

print(X.sum(dim=1, keepdim=False)) # dim为1,按照相同的行求和,不在结果中保留行特征

tensor([[5, 7, 9]])

tensor([[ 6],
[15]])

tensor([5, 7, 9])

tensor([ 6, 15])

tensor.gather()

tensor.item()

tensor只有一个元素时,直接输出是张量形式,item()输出是值

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!