python pytorch模型转onnx模型的全过程(多输入+动态维度)

 更新时间:2024年03月21日 11:51:25   作者:暗号9  
这篇文章主要介绍了python pytorch模型转onnx模型的全过程(多输入+动态维度),本文给大家记录记录了pt文件转onnx全过程,简单的修改即可应用,结合实例代码给大家介绍的非常详细,感兴趣的朋友一起看看吧

(多输入+动态维度)整理的自定义神经网络pt转onnx过程的python代码,记录了pt文件转onnx全过程,简单的修改即可应用。

pt文件转onnx步骤 

1、编写预处理代码

预处理代码 与torch模型的预处理代码一样

def preprocess(img):
	img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
	img = np.expand_dims(img, 0)
	sh_im = img.shape
	if sh_im[2]%2==1:
    	img = np.concatenate((img, img[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
	if sh_im[3]%2==1:
    	img = np.concatenate((img, img[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
	img = normalize(img)
	img = torch.Tensor(img)
	return img

2、用onnxruntime导出onnx

def export_onnx(net, model_path, img, nsigma, onnx_outPath):
	nsigma /= 255.
	if torch.cuda.is_available():
    	state_dict = torch.load(model_path)
    	model = net.cuda()
    	dtype = torch.cuda.FloatTensor
    else:
    	state_dict = torch.load(model_path, map_location='cpu')
    	state_dict = remove_dataparallel_wrapper(state_dict)
    	model = net
    	dtype = torch.FloatTensor
	img = Variable(img.type(dtype))
	nsigma = Variable(torch.FloatTensor([nsigma]).type(dtype))
	# 我这里预训练权重中参数名字与网络名字不同
	# 相同的话可直接load_state_dict(state_dict)
	new_state_dict = {}
	for k, v in state_dict.items():
    	new_state_dict[k[7:]] = v
	model.load_state_dict(new_state_dict)
	# 设置onnx的输入输出列表,多输入多输出就设置多个
	input_list = ['input', 'nsigma']
	output_list = ['output']
	# onnx模型导出
	# dynamic_axes为动态维度,如果自己的输入输出是维度变化的建议设置,否则只能输入固定维度的tensor
	torch.onnx.export(model, (img, nsigma), onnx_outPath, verbose=True, opset_version=11, export_params=True,
		 				input_names=input_list, output_names=output_list,
		 				dynamic_axes={'input_img': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'},
		 				'output': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'}}) 

导出结果

3、对导出的模型进行检查

此处为检查onnx模型节点,后面如果onnx算子不支持转engine时,方便定位节点,找到不支持的算子进行修改

def check_onnx(onnx_model_path):
	model = onnx.load(onnx_model_path)
	onnx.checker.check_model((model))
	print(onnx.helper.printable_graph(model.graph))

下面贴出输出结果


netron可视化

4、推理onnx模型,查看输出是否一致

    def run_onnx(onnx_model_path, test_img, nsigma):
		nsigma /= 255.
		with torch.no_grad:
    	# 这里默认是cuda推理torch.cuda.FloatTensor
    	img = Variable(test_img.type(torch.cuda.FloatTensor))
    	nsigma = Variable(torch.FloatTensor([nsigma]).type(torch.cuda.FloatTensor))
    	# 设置GPU推理
    	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    	providers = ['CUDAExecutionProvider'] if device != "cpu" else ['CPUExecutionProvider']
    	# 通过创建onnxruntime session来运行onnx模型
    	ort_session = ort.InferenceSession(onnx_model_path, providers=providers)
    	output = ort_session.run(output_names=['output'],
                             	input_feed={'input_img': np.array(img.cpu(), dtype=np.float32),
                                'nsigma':  np.array(nsigma.cpu(), dtype=np.float32)})
		return output

5、对onnx模型的输出进行处理,显示cv图像

def postprocess(img, img_noise_estime):
    out = torch.clamp(img-img_noise_estime, 0., 1.)
    outimg = variable_to_cv2_image(out)
    cv2.imshow(outimg)

6、编辑主函数进行测试

def main():
    ##############################
    #
    #        onnx模型导出
    #
    ##############################
    # pt权重路径:自己的路径 + mypt.pt
    model_path = "D:/python/ffdnet-pytorch/models/net_rgb.pth"
    # export onnx模型时输入进去数据,用于onnx记录网络的计算过程
    export_feed_path = "D:/python/ffdnet-pytorch/noisy.png"
    # onnx模型导出的路径
    onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx"
    # 实例化自己的网络模型并设置输入参数
    net = FFDNet(num_input_channels=3)
    nsigma = 25
    # onnx 导出
    img = cv2.imread(export_feed_path)
    input = preprocess(img)
    export_onnx(net, model_path, input, nsigma, onnx_outpath)
    print("export success!")
    ##############################
    #
    #        检查onnx模型
    #
    ##############################
    check_onnx(onnx_outpath)
    # netron可视化网络,可视化用节点记录的网络推理流程
    netron.start(onnx_outpath)
    ##############################
    #
    #        运行onnx模型
    #
    ##############################
    # 此处过程是数据预处理 ---> 调用run_onnx函数 ---> 对模型输出后处理
    # 具体代码就不再重复了

#完整代码

import time
import netron
import cv2
import torch
import onnx
import numpy as np
from torch.autograd import Variable
import onnxruntime as ort
from models import FFDNet
from utils import remove_dataparallel_wrapper, normalize, variable_to_cv2_image
# 此处为预处理代码 与torch模型的预处理代码一样
def preprocess(img):
    img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
    img = np.expand_dims(img, 0)
    sh_im = img.shape
    if sh_im[2]%2==1:
        img = np.concatenate((img, img[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
    if sh_im[3]%2==1:
        img = np.concatenate((img, img[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
    img = normalize(img)
    img = torch.Tensor(img)
    return img
# 此处为onnx模型导出的代码,包括torch模型的pt权重加载,onnx模型的导出
def export_onnx(net, model_path, img, nsigma, onnx_outPath):
    nsigma /= 255.
    if torch.cuda.is_available():
        state_dict = torch.load(model_path)
        model = net.cuda()
        dtype = torch.cuda.FloatTensor
    else:
        state_dict = torch.load(model_path, map_location='cpu')
        state_dict = remove_dataparallel_wrapper(state_dict)
        model = net
        dtype = torch.FloatTensor
    img = Variable(img.type(dtype))
    nsigma = Variable(torch.FloatTensor([nsigma]).type(dtype))
    # 我这里预训练权重中参数名字与网络名字不同
    # 相同的话可直接load_state_dict(state_dict)
    new_state_dict = {}
    for k, v in state_dict.items():
        new_state_dict[k[7:]] = v
    model.load_state_dict(new_state_dict)
    # 设置onnx的输入输出列表,多输入多输出就设置多个
    input_list = ['input', 'nsigma']
    output_list = ['output']
    # onnx模型导出
    # dynamic_axes为动态维度,如果自己的输入输出是维度变化的建议设置,否则只能输入固定维度的tensor
    torch.onnx.export(model, (img, nsigma), onnx_outPath, verbose=True, opset_version=11, export_params=True,
                      input_names=input_list, output_names=output_list,
                      dynamic_axes={'input_img': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'},
                                    'output': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'}})
# 此处为检查onnx模型节点,后面如果onnx算子不支持转engine时,方便定位节点,找到不支持的算子进行修改
def check_onnx(onnx_model_path):
    model = onnx.load(onnx_model_path)
    onnx.checker.check_model((model))
    print(onnx.helper.printable_graph(model.graph))
# 此处为推理onnx模型的代码,检查输出是否跟torch模型相同
def run_onnx(onnx_model_path, test_img, nsigma):
    nsigma /= 255.
    with torch.no_grad:
        # 这里默认是cuda推理torch.cuda.FloatTensor
        img = Variable(test_img.type(torch.cuda.FloatTensor))
        nsigma = Variable(torch.FloatTensor([nsigma]).type(torch.cuda.FloatTensor))
        # 设置GPU推理
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        providers = ['CUDAExecutionProvider'] if device != "cpu" else ['CPUExecutionProvider']
        # 通过创建onnxruntime session来运行onnx模型
        ort_session = ort.InferenceSession(onnx_model_path, providers=providers)
        output = ort_session.run(output_names=['output'],
                                 input_feed={'input_img': np.array(img.cpu(), dtype=np.float32),
                                             'nsigma':  np.array(nsigma.cpu(), dtype=np.float32)})
    return output
# 此处是后处理代码,将onnx模型的输出处理成可显示cv图像
# 与torch模型的后处理一样
def postprocess(img, img_noise_estime):
    out = torch.clamp(img-img_noise_estime, 0., 1.)
    outimg = variable_to_cv2_image(out)
    cv2.imshow(outimg)
def main():
    ##############################
    #
    #        onnx模型导出
    #
    ##############################
    # pt权重路径:自己的路径 + mypt.pt
    model_path = "D:/python/ffdnet-pytorch/models/net_rgb.pth"
    # export onnx模型时输入进去数据,用于onnx记录网络的计算过程
    export_feed_path = "D:/python/ffdnet-pytorch/noisy.png"
    # onnx模型导出的路径
    onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx"
    # 实例化自己的网络模型并设置输入参数
    net = FFDNet(num_input_channels=3)
    nsigma = 25
    # onnx 导出
    img = cv2.imread(export_feed_path)
    input = preprocess(img)
    export_onnx(net, model_path, input, nsigma, onnx_outpath)
    print("export success!")
    ##############################
    #
    #        检查onnx模型
    #
    ##############################
    onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx"
    check_onnx(onnx_outpath)
    # netron可视化网络,可视化用节点记录的网络推理流程
    netron.start(onnx_outpath)
    ##############################
    #
    #        运行onnx模型
    #
    ##############################
    # 此处过程是数据预处理 ---> 调用run_onnx函数 ---> 对模型输出后处理
    # 具体代码就不再重复了
if __name__ == '__main__':
    main()

到此这篇关于python pytorch模型转onnx模型(多输入+动态维度)的文章就介绍到这了,更多相关python pytorch模型转onnx模型内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python:按行读入,排序然后输出的方法

    python:按行读入,排序然后输出的方法

    今天小编就为大家分享一篇python:按行读入,排序然后输出的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • Python列表删除重复元素与图像相似度判断及删除实例代码

    Python列表删除重复元素与图像相似度判断及删除实例代码

    这篇文章主要给大家介绍了关于Python列表删除重复元素与图像相似度判断及删除的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-05-05
  • 详解django2中关于时间处理策略

    详解django2中关于时间处理策略

    这篇文章主要介绍了详解django2中关于时间处理策略,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-03-03
  • 如何使用Python程序完成描述性统计分析需求

    如何使用Python程序完成描述性统计分析需求

    这篇文章主要介绍了如何使用Python程序完成描述性统计分析需求,运用制表和分类,图形以及计算概括性数据来描述数据特征的各项活动,需要的朋友可以参考下
    2023-03-03
  • Python更改pip镜像源的方法示例

    Python更改pip镜像源的方法示例

    这篇文章主要介绍了Python更改pip镜像源的方法示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-12-12
  • Python中实现从目录中过滤出指定文件类型的文件

    Python中实现从目录中过滤出指定文件类型的文件

    这篇文章主要介绍了Python中实现从目录中过滤出指定文件类型的文件,本文是一篇学笔记,实例相对简单,需要的朋友可以参考下
    2015-02-02
  • 关于np.meshgrid函数中的indexing参数问题

    关于np.meshgrid函数中的indexing参数问题

    Meshgrid函数在二维与三维空间中用于生成坐标网格,便于进行图像处理和空间数据分析,二维情况下,默认使用笛卡尔坐标系,而三维meshgrid则涉及不同的坐标轴取法,在三维情况下,可能会出现坐标轴排列序混乱
    2024-09-09
  • Python 日期与时间转换的方法

    Python 日期与时间转换的方法

    这篇文章主要介绍了Python 日期与时间转换的方法,文中讲解非常细致,代码帮助大家更好的理解和学习,感兴趣的朋友可以了解下
    2020-08-08
  • Python3爬虫中识别图形验证码的实例讲解

    Python3爬虫中识别图形验证码的实例讲解

    在本篇内容里小编给大家分享的是关于Python3爬虫中识别图形验证码的实例讲解内容,需要的朋友们可以学习参考下。
    2020-07-07
  • Python3.7安装keras和TensorFlow的教程图解

    Python3.7安装keras和TensorFlow的教程图解

    这篇文章主要介绍了Python3.7安装keras和TensorFlow经验,本文图文并茂给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-10-10

最新评论