Why doesn't my simple PyTorch network work on a GPU device?

鱼传尺愫 2020-12-20 23:45

I built a simple network from a tutorial and I got this error:

RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.Flo

2 answers
  •  半阙折子戏
    2020-12-21 00:03

    import torch
    import numpy as np
    
    x = torch.tensor(np.array(1), device='cuda:0')
    
    print(x.device)  # Reported to print `cpu` (likely a bug in an old PyTorch release)
    
    x = torch.tensor(1, device='cuda:0')
    
    print(x.device)  # Prints `cuda:0`
    

    In the second case the tensor resides on the GPU.
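
The error in the question usually means the model's parameters and the input tensor live on different devices. A minimal sketch of keeping them together follows; the `nn.Linear` layer and the tensor shapes are hypothetical stand-ins for the tutorial network, which is not shown in the question, and the CPU fallback is an assumption added so the snippet also runs without CUDA.

```python
import torch
import torch.nn as nn

# Fall back to the CPU so the sketch also runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in for the tutorial network.
model = nn.Linear(4, 2)
model.to(device)  # nn.Module.to() moves the parameters in place

inputs = torch.randn(3, 4)   # created on the CPU by default
inputs = inputs.to(device)   # Tensor.to() returns a tensor, so reassign the name

outputs = model(inputs)
print(outputs.device)
```

Calling `inputs.to(device)` without reassigning the result leaves `inputs` on the CPU, which is one common way to end up with this kind of type mismatch.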
