在训练完网络后,我们需要将需要识别的图片输入模型中并输出结果,下面是具体代码。
import torchfrom PIL import Imageimport torchvision.transforms as transdef testImage(): net = torch.load("models/net.pth") img = Image.open("images/car.jpg") transform = trans.Compose([ trans.Resize(32), trans.CenterCrop(32), trans.ToTensor(), trans.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)) ]) img = transform(img).unsqueeze(0) output = net(img.cuda()) classes = ('plane','car', 'bird','cat','deer','dog','frog','horse','ship','truck') index = output.argmax(1) print("预测结果是:{0}".format(classes[index]))testImage()