pytorch如何使用训练好的模型预测新数据

 更新时间:2023年06月15日 09:03:45   作者:Xiuxiu_Law  
这篇文章主要介绍了pytorch如何使用训练好的模型预测新数据问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch使用训练好的模型预测新数据

神经网络在进行完训练和测试后,如果达到了较高的正确率的话,我们可以尝试将模型用于预测新数据。

总共需要两大部分:神经网络、预测函数(新图片的加载,传入模型、得出结果)。

完整代码

import torch, glob, cv2
from torchvision import transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):  # 神经网络部分用你自己的
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 2, 1)  # nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(32, 64, 3, 2, 1)
        self.conv3 = nn.Conv2d(64, 128, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(6272, 128)  # 6272=128*7*7
        self.fc2 = nn.Linear(128, 8)
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        self.output = F.log_softmax(x, dim=1)
        out1 = x
        return self.output,out1
def predict():
    model = Net()
    model.load_state_dict(torch.load('test.pt'))
    torch.no_grad()
    imgfile = glob.glob(r"")  # 输入要预测的图片所在路径
    print(len(imgfile), imgfile)
    for i in imgfile:
        imgfile1 = i.replace("\\", "/")
        img = cv2.imdecode(np.fromfile(imgfile1, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (64, 64))  # 是否需要resize取决于新图片格式与训练时的是否一致
        tran = transforms.ToTensor()
        img = img.reshape((*img.shape, -1))
        img = tran(img)
        img = img.unsqueeze(0)
        outputs, out1 = model(img)  # outputs,out1修改为你的网络的输出
        predicted, index  = torch.max(out1, 1)
        degre = int(index[0])
        list = [0, 45, -45, -90, 90, 135, -135, 180]
        print(predicted, list[degre])
if __name__ == '__main__':
    predict()

神经网络部分复制你在训练时定义的神经网络即可,如果模型保存为字典,则需要

model.load_state_dict(torch.load('test.pt'))

新图片的格式需要与训练测试时的图片格式保持一致,所以需要resize,如果新图片为相同格式略过。

最后的list是你样本类别的list,每一类的索引需要与label保持一致,例如:

list = ['裤子', '套衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '短靴']

结果分析

tensor([7.0595], grad_fn=<MaxBackward0>) 45
tensor([11.9538], grad_fn=<MaxBackward0>) -45
tensor([5.8450], grad_fn=<MaxBackward0>) 135

前面的张量tensor代表了各个类别的“概率”中最大的那一个,然后根据最大“概率”所在的位置(index)来找到list所对应的类别,然后输出。

pytorch框架--简单模型预测

模型预测示例

使用训练好的模型进行预测

import torchvision
from model import Tudui
import torch
from PIL import Image
# 读取图像
img = Image.open("./data/train/Dog/9.jpg")
# 数据预处理
# 缩放
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(img)
print(image.shape)
# 根据保存方式加载
model = torch.load("tudui_99.pth", map_location=torch.device('cpu'))
# 注意维度转换,单张图片
image1 = torch.reshape(image, (1, 3, 32, 32))
# 测试开关
model.eval()
# 节约性能
with torch.no_grad():
    output = model(image1)
print(output)
# print(output.argmax(1))
# 定义类别对应字典
dist = {0: "飞机", 1: "汽车", 2: "鸟", 3: "猫", 4: "鹿", 5: "狗", 6: "青蛙", 7: "马", 8: "船", 9: "卡车"}
# 转numpy格式,列表内取第一个
a = dist[output.argmax(1).numpy()[0]]
img.show()
print(a)

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Label Propagation算法原理示例解析

    Label Propagation算法原理示例解析

    这篇文章主要为大家介绍了Label Propagation算法原理示例解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-02-02
  • python3实现倒计时效果

    python3实现倒计时效果

    这篇文章主要为大家详细介绍了python3实现倒计时效果,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-08-08
  • Python实现打印彩色字符串的方法详解

    Python实现打印彩色字符串的方法详解

    print 也许是我们在使用 Python 的时候用的最多的一种操作,但是经常发现很多人可以打印彩色文本,这种操作是怎么得到的呢?本文就来为大家详细讲讲
    2022-08-08
  • Python之tkinter文字区域Text使用及说明

    Python之tkinter文字区域Text使用及说明

    这篇文章主要介绍了Python之tkinter文字区域Text使用及说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-05-05
  • 新手学python应该下哪个版本

    新手学python应该下哪个版本

    在本篇内容中小编给大家整理的是关于新手学python应该下版本的相关知识点,需要的朋友们可以参考学习下。
    2020-06-06
  • Django中间件拦截未登录url实例详解

    Django中间件拦截未登录url实例详解

    在本篇文章里小编给各位整理了关于Django中间件拦截未登录url的实例内容以及相关知识点,有需要的朋友们可以学习下。
    2019-09-09
  • python+pytest接口自动化之token关联登录的实现

    python+pytest接口自动化之token关联登录的实现

    公司某管理后台系统,登录后返回token,接着去请求其他接口时请求头中都需要加上这个token,否则提示请先登录,今天通过本文给大家介绍下python+pytest接口自动化之token关联登录的实现,感兴趣的朋友一起看看吧
    2022-04-04
  • 使用TensorFlow创建生成式对抗网络GAN案例

    使用TensorFlow创建生成式对抗网络GAN案例

    这篇文章主要为大家介绍了使用TensorFlow创建生成式对抗网络GAN案例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-03-03
  • Python 获取中文字拼音首个字母的方法

    Python 获取中文字拼音首个字母的方法

    今天小编就为大家分享一篇Python 获取中文字拼音首个字母的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • 基于Python实现人机PK小游戏

    基于Python实现人机PK小游戏

    这篇文章主要为大家详细介绍了如何基于Python实现人机PK小游戏,简单来说,就是随机生成玩家和敌人的属性,同时互相攻击,直至一方血量小于零,感兴趣的小伙伴可以学习一下
    2023-06-06

最新评论