pytorch笔记7--批训练

江枫思渺然 提交于 2020-02-06 16:58:28
import torch
import torch.utils.data as Data    #用于小批训练
torch.manual_seed(1)   #为cpu设置随机种子,使多次运行结果一致
# torch.cuda.manual_seed(seed)  #为当前GPU设置随机种子
#torch.cuda.manual_seed_all(seed)  #为所有GPU设置随机种子

Batch_Size = 4
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

#用DataLoader来包装数据,用于批训练(首先将数据转换为torch能识别的Dataset形式)
torch_dataset=Data.TensorDataset(x,y)

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=Batch_Size,
    shuffle=True,   #是否打乱数据
)

for epoch in range(3): #将所有数据训练3次
    for step,(batch_x,batch_y) in enumerate(loader): #每一步loader释放一小批数据

        ...
        print('Epoch:{}, Step:{}, batch x:{}, batch y:{}'.format(epoch,step,batch_x.numpy(),batch_y.numpy()))

结果:

Epoch:0, Step:0, batch x:[ 5.  7. 10.  3.], batch y:[6. 4. 1. 8.]
Epoch:0, Step:1, batch x:[4. 2. 1. 8.], batch y:[ 7.  9. 10.  3.]
Epoch:0, Step:2, batch x:[9. 6.], batch y:[2. 5.]
Epoch:1, Step:0, batch x:[ 4.  6.  7. 10.], batch y:[7. 5. 4. 1.]
Epoch:1, Step:1, batch x:[8. 5. 3. 2.], batch y:[3. 6. 8. 9.]
Epoch:1, Step:2, batch x:[1. 9.], batch y:[10.  2.]
Epoch:2, Step:0, batch x:[4. 2. 5. 6.], batch y:[7. 9. 6. 5.]
Epoch:2, Step:1, batch x:[10.  3.  9.  1.], batch y:[ 1.  8.  2. 10.]
Epoch:2, Step:2, batch x:[8. 7.], batch y:[3. 4.]

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!