# Saving and loading a model's state_dict across devices (GPU -> CPU, GPU -> GPU, CPU -> GPU)
import torch
import torch.nn as nn
class Model(nn.Module):
    """Logistic-regression model: ``sigmoid(linear(x))`` with one output unit.

    Args:
        n_input_features: Size of the last dimension of the input tensor.
    """

    def __init__(self, n_input_features: int) -> None:
        super().__init__()
        # A single linear unit mapping n_input_features -> 1.
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(linear(x))``.

        The last dimension of ``x`` must equal ``n_input_features``; the
        output has the same leading dimensions and a last dimension of 1,
        with values in [0, 1].
        """
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred
# Build the model (training is omitted from this demo) and choose the
# checkpoint path that every save/load below writes to and reads from.
model = Model(n_input_features=6)
FILE = "model_gpu_cpu.pth"
# ---------------------------------------------------------------
# Save on GPU, Load on CPU
# ---------------------------------------------------------------
device = torch.device("cuda")
model.to(device)  # nn.Module.to moves the parameters in place
torch.save(model.state_dict(), FILE)

device = torch.device("cpu")
model = Model(n_input_features=6)
# map_location=device remaps the CUDA tensors stored in FILE onto the CPU,
# so this load also works on a machine without a GPU.
model.load_state_dict(
    torch.load(FILE, map_location=device)
)
print(
    "Save on GPU, load on CPU: ",
    next(model.parameters()).is_cuda,
)
# ---------------------------------------------------------------
# Save on GPU, Load on GPU
# ---------------------------------------------------------------
device = torch.device("cuda")
model.to(device)
torch.save(model.state_dict(), FILE)

model = Model(n_input_features=6)
# Without map_location, torch.load puts tensors back on the device they were
# saved from (the GPU here). load_state_dict copies them into the freshly
# built model, whose own parameters still have to be moved to the GPU.
model.load_state_dict(
    torch.load(FILE)
)
model.to(device)
# Check whether the model's parameters are on the GPU.
print(
    "Save on GPU, load on GPU: ",
    next(model.parameters()).is_cuda,
)
# Save on CPU, Load on GPU
# The model left over from the previous section may live on the GPU. Move it
# to the CPU first so the checkpoint really contains CPU tensors. Without
# this, the "save on CPU" scenario is never exercised.
model.to(torch.device("cpu"))
torch.save(model.state_dict(), FILE)
device = torch.device("cuda")
model = Model(n_input_features=6)
# map_location="cuda:0" makes torch.load place every tensor on the first GPU.
model.load_state_dict(torch.load(FILE, map_location="cuda:0"))
# load_state_dict only copies values into the existing (CPU) parameters, so
# the module itself must still be moved to the GPU.
model.to(device)
print("Save on CPU, load on GPU: ", next(model.parameters()).is_cuda)
# Source: oschina
# Link: https://my.oschina.net/u/4228078/blog/4324005