Saving and Loading a Model on GPU and CPU

Submitted by 我与影子孤独终老i on 2020-08-07 07:51:29

# Saving and loading a PyTorch model across GPU and CPU

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_input_features=6)

# train your model...

FILE = "model_gpu_cpu.pth"
"""
Save on GPU, Load on CPU
"""
device = torch.device("cuda")
model.to(device)                      # move the parameters onto the GPU
torch.save(model.state_dict(), FILE)  # the saved state_dict now holds CUDA tensors

device = torch.device("cpu")
model = Model(n_input_features=6)
# map_location remaps the CUDA tensors onto the CPU while the file is being read
model.load_state_dict(torch.load(FILE, map_location=device))
print("Save on GPU, load on CPU: ", next(model.parameters()).is_cuda)

# Save on GPU, Load on GPU

device = torch.device("cuda")
model.to(device)
torch.save(model.state_dict(), FILE)

model = Model(n_input_features=6)
# without map_location the saved CUDA tensors are deserialized straight onto the GPU
model.load_state_dict(torch.load(FILE))
model.to(device)  # convert the model's own parameter tensors to CUDA tensors as well

# 查看model的参数是否在GPU上
print("Save on GPU, load on GPU: ", next(model.parameters()).is_cuda)

# Save on CPU, Load on GPU

model.to(torch.device("cpu"))         # move the model back to the CPU before saving
torch.save(model.state_dict(), FILE)  # the saved state_dict now holds CPU tensors

device = torch.device("cuda")
model = Model(n_input_features=6)
# map_location="cuda:0" puts the loaded tensors onto GPU 0 while the file is being read
model.load_state_dict(torch.load(FILE, map_location="cuda:0"))
model.to(device)                      # make sure every parameter of the model is on the GPU
print("Save on CPU, load on GPU: ", next(model.parameters()).is_cuda)
