pytorch模型转onnx模型的方法详解

 更新时间:2022年08月30日 11:55:40   作者:挣扎的笨鸟  
很多时候有pytorch模型转onnx模型的必要,比如用tensorRT加速的时候,下面这篇文章主要给大家介绍了关于pytorch模型转onnx模型的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下

学习目标

1.掌握pytorch模型转换到onnx模型

2.顺利运行onnx模型

3.比对onnx模型和pytorch模型的输出结果

学习大纲

  • pytorch模型转换onnx模型
  • 运行onnx模型
  • onnx模型输出与pytorch模型比对

学习内容

前提条件:需要安装onnx 和 onnxruntime,可以通过 pip install onnx 和 pip install onnxruntime 进行安装

1 . pytorch 转 onnx

pytorch 转 onnx 只需要一个函数 torch.onnx.export

torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)

参数说明:

  • model——需要导出的pytorch模型
  • args——模型的输入参数,满足输入层的shape正确即可。
  • path——输出的onnx模型的位置。例如‘yolov5.onnx’。
  • export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。
  • verbose——是否打印模型转换信息。default=False。
  • input_names——输入节点名称。default=None。
  • output_names——输出节点名称。default=None。
  • do_constant_folding——是否使用常量折叠(不了解),默认即可。default=True。
  • dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道。
    格式如下 :
    1)仅list(int) dynamic_axes={‘input’:[0,2,3],‘output’:[0,1]}
    2)仅dict<int, string> dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:{0:‘batch’,1:‘c’}}
    3)mixed dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:[0,1]}
  • opset_version——opset的版本,低版本不支持upsample等操作。
import torch
import torch.nn
import onnx

model = torch.load('best.pt')
model.eval()

input_names = ['input']
output_names = ['output']

x = torch.randn(1,3,32,32,requires_grad=True)

torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')

2 . 运行onnx模型

检查onnx模型,并使用onnxruntime运行。

import onnx
import onnxruntime as ort

model = onnx.load('best.onnx')
onnx.checker.check_model(model)

session = ort.InferenceSession('best.onnx')
x=np.random.randn(1,3,32,32).astype(np.float32)  # 注意输入type一定要np.float32!!!!!
# x= torch.randn(batch_size,chancel,h,w)


outputs = session.run(None,input = { 'input' : x })

参数说明:

  • output_names: default=None
    用来指定输出哪些,以及顺序
    若为None,则按序输出所有的output,即返回[output_0,output_1]
    若为[‘output_1’,‘output_0’],则返回[output_1,output_0]
    若为[‘output_0’],则仅返回[output_0:tensor]
  • input:dict
    可以通过session.get_inputs().name获得名称
    其中key值要求与torch.onnx.export中设定的一致

3.onnx模型输出与pytorch模型比对

import numpy as np
np.testing.assert_allclose(torch_result[0].detach().numpu(),onnx_result,rtol=0.0001)

如前所述,经验表明,ONNX 模型的运行效率明显优于原 PyTorch 模型,这似乎是源于 ONNX 模型生成过程中的优化,这也导致了模型的生成过程比较耗时,但整体效率依旧可观。

此外,根据对 ONNX 模型和 PyTorch 模型运行结果的统计分析(误差的均值和标准差),可以看出 ONNX 模型的运行结果误差很小、基本可靠。

内容参考:https://zhuanlan.zhihu.com/p/422290231

总结

到此这篇关于pytorch模型转onnx模型的文章就介绍到这了,更多相关pytorch模型转onnx模型内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python IO文件管理的具体使用

    Python IO文件管理的具体使用

    我们可以使用python来操作文件,比如读取文件内容、写入新的内容等,本文主要介绍了Python IO文件管理的具体使用,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-03-03
  • python框架Django实战商城项目之工程搭建过程图文详解

    python框架Django实战商城项目之工程搭建过程图文详解

    这篇文章主要介绍了python框架Django实战商城项目之工程搭建过程,这个项目很像京东商城,项目开发采用前后端不分离的模式,本文通过图文并茂的形式给大家介绍的非常详细,需要的朋友可以参考下
    2020-03-03
  • 使用python实现tcp自动重连

    使用python实现tcp自动重连

    下面小编就为大家带来一篇使用python实现tcp自动重连实现方法。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。
    2017-07-07
  • 用tensorflow实现弹性网络回归算法

    用tensorflow实现弹性网络回归算法

    这篇文章主要介绍了用tensorflow实现弹性网络回归算法
    2018-01-01
  • 使用python Telnet远程登录执行程序的方法

    使用python Telnet远程登录执行程序的方法

    今天小编就为大家分享一篇使用python Telnet远程登录执行程序的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • Python prettytable模块应用详解

    Python prettytable模块应用详解

    PrettyTable 是python中的一个第三方库,可用来生成美观的ASCII格式的表格,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧
    2022-09-09
  • python实现鸢尾花三种聚类算法(K-means,AGNES,DBScan)

    python实现鸢尾花三种聚类算法(K-means,AGNES,DBScan)

    这篇文章主要介绍了python实现鸢尾花三种聚类算法(K-means,AGNES,DBScan),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-06-06
  • python使用正则表达式替换匹配成功的组

    python使用正则表达式替换匹配成功的组

    正则表达式,又称正规表示式、正规表示法、正规表达式、规则表达式、常规表示法。这篇文章主要介绍了python里使用正则表达式来替换匹配成功的组,需要的朋友可以参考下
    2017-11-11
  • win10下安装Anaconda的教程(python环境+jupyter_notebook)

    win10下安装Anaconda的教程(python环境+jupyter_notebook)

    Anaconda指的是一个开源的Python发行版本,其包含了conda、Python等180多个科学包及其依赖项。这篇文章主要介绍了win10下安装Anaconda(python环境+jupyter_notebook),需要的朋友可以参考下
    2019-10-10
  • python二分查找搜索算法的多种实现方法

    python二分查找搜索算法的多种实现方法

    二分查找,也称折半查找,是一种效率较高的查找方法,本文主要介绍了python二分查找搜索算法的多种实现方法,具有一定的参考价值,感兴趣的可以了解一下
    2024-03-03

最新评论