pytorch模型转换为onnx可视化(使用netron)

 更新时间:2023年05月18日 09:44:53   作者:michaelchengjl  
netron 是一个非常好用的网络结构可视化工具,但是netron对pytorch模型的支持还不成熟,这篇文章主要介绍了pytorch模型转换为onnx,并使用netron可视化,需要的朋友可以参考下

pytorch模型转换为onnx,并使用netron可视化

netron 是一个非常好用的网络结构可视化工具。

但是netron对pytorch模型的支持还不成熟。自己试的效果是生成的模型图没有连线。

目前支持的框架 根据netron的github

目前netron支持:

ONNX (.onnx, .pb, .pbtxt)
Keras (.h5, .keras)
Core ML (.mlmodel)
Caffe (.caffemodel, .prototxt)
Caffe2 (predict_net.pb, predict_net.pbtxt)
Darknet (.cfg)
MXNet (.model, -symbol.json)
ncnn (.param) 
TensorFlow Lite (.tflite)
PaddlePaddle (.zip, model)
TensorFlow.js
CNTK (.model, .cntk)

并且实验性支持:

TorchScript (.pt, .pth)
PyTorch (.pt, .pth)
Torch (.t7)
Arm NN (.armnn)
BigDL (.bigdl, .model) 
Chainer (.npz, .h5)
Deeplearning4j (.zip)
MediaPipe (.pbtxt)
ML.NET (.zip), MNN (.mnn)
OpenVINO (.xml)
scikit-learn (.pkl)
TensorFlow (.pb, .meta, .pbtxt, .ckpt, .index)

Netron supports ONNX, TensorFlow Lite, Caffe, Keras, Darknet, PaddlePaddle, ncnn, MNN, Core ML, RKNN, MXNet, MindSpore Lite, TNN, Barracuda, Tengine, CNTK, TensorFlow.js, Caffe2 and UFF.

Netron has experimental support for PyTorch, TensorFlow, TorchScript, OpenVINO, Torch, Vitis AI, kmodel, Arm NN, BigDL, Chainer, Deeplearning4j, MediaPipe, ML.NET and scikit-learn.

这里就有一个把 .pth 模型转化为 .onnx 模型。

Pytorch模型转onnx

model = resnet18(pretrained=True)
# print(model)
# old_net_path = "resnet18.pth"
new_net_path = "./resnet18.onnx"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 导入模型
net = model.to(device)
# net.load_state_dict(torch.load(old_net_path, map_location=device))
net.eval()
input = torch.randn(1, 3, 224, 224).to(device)  # BCHW  其中Batch必须为1,因为测试时一般为1,尺寸HW必须和训练时的尺寸一致
torch.onnx.export(net, input, new_net_path, verbose=False)

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None)

参数:

model(torch.nn.Module)-要被导出的模型
args(参数的集合)-模型的输入,例如,这种model(*args)方式是对模型的有效调用。任何非Variable参数都将硬编码到导出的模型中;任何Variable参数都将成为导出的模型的输入,并按照他们在args中出现的顺序输入。如果args是一个Variable,这等价于用包含这个Variable的1-ary元组调用它。(注意:现在不支持向模型传递关键字参数。)
f-一个类文件的对象(必须实现文件描述符的返回)或一个包含文件名字符串。一个二进制Protobuf将会写入这个文件中。
export_params(bool,default True)-如果指定,所有参数都会被导出。如果你只想导出一个未训练的模型,就将此参数设置为False。在这种情况下,导出的模型将首先把所有parameters作为参arguments,顺序由model.state_dict().values()指定。
verbose(bool,default False)-如果指定,将会输出被导出的轨迹的调试描述。
training(bool,default False)-导出训练模型下的模型。目前,ONNX只面向推断模型的导出,所以一般不需要将该项设置为True。
input_names(list of strings, default empty list)-按顺序分配名称到图中的输入节点。
output_names(list of strings, default empty list)-按顺序分配名称到图中的输出节点。

文件中保存模型结构和权重参数

import torch
torch_model = torch.load("save.pt") # pytorch模型加载
batch_size = 1  #批处理大小
input_shape = (3,244,244)   #输入数据
# set the model to inference mode
torch_model.eval()
x = torch.randn(batch_size,*input_shape)		# 生成张量
export_onnx_file = "test.onnx"					# 目的ONNX文件名
torch.onnx.export(torch_model,
                    x,
                    export_onnx_file,
                    opset_version=10,
                    do_constant_folding=True,	# 是否执行常量折叠优化
                    input_names=["input"],		# 输入名
                    output_names=["output"],	# 输出名
                    dynamic_axes={"input":{0:"batch_size"},		# 批处理变量
                    "output":{0:"batch_size"}})

dynamic_axes字段用于批处理.若不想支持批处理或固定批处理大小,移除dynamic_axes字段即可.

文件中只保留模型权重

import torch
torch_model = selfmodel()  					# 由研究员提供python.py文件
batch_size = 1 								# 批处理大小
input_shape = (3, 244, 244) 				# 输入数据
# set the model to inference mode
torch_model.eval()
x = torch.randn(batch_size,*input_shape) 	# 生成张量
export_onnx_file = "test.onnx" 				# 目的ONNX文件名
torch.onnx.export(torch_model,
                    x,
                    export_onnx_file,
                    opset_version=10,
                    do_constant_folding=True,	# 是否执行常量折叠优化
                    input_names=["input"],		# 输入名
                    output_names=["output"],	# 输出名
                    dynamic_axes={"input":{0:"batch_size"},	# 批处理变量
                    "output":{0:"batch_size"}})

到此这篇关于pytorch模型转换为onnx可视化(使用netron)的文章就介绍到这了,更多相关pytorch模型转onnx可视化内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • pycharm重置设置,恢复默认设置的方法

    pycharm重置设置,恢复默认设置的方法

    今天小编就为大家分享一篇pycharm重置设置,恢复默认设置的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python源码学习之PyType_Type和PyBaseObject_Type详解

    Python源码学习之PyType_Type和PyBaseObject_Type详解

    今天给大家带来的是关于Python源码的相关知识学习,文章围绕着PyType_Type和PyBaseObject_Type展开,文中有非常详细的介绍及代码示例,需要的朋友可以参考下
    2021-06-06
  • python中使用 unittest.TestCase单元测试的用例详解

    python中使用 unittest.TestCase单元测试的用例详解

    python 在unittest.TestCase 中提高了很多断言方法,这篇文章主要介绍了python中使用 unittest.TestCase 进行单元测试的操作方法,需要的朋友可以参考下
    2021-08-08
  • PyTorch的安装与使用示例详解

    PyTorch的安装与使用示例详解

    本文介绍了热门AI框架PyTorch的conda安装方案,与简单的自动微分示例,并顺带讲解了一下PyTorch开源Github仓库中的两个Issue内容,需要的朋友可以参考下
    2024-05-05
  • python numpy和list查询其中某个数的个数及定位方法

    python numpy和list查询其中某个数的个数及定位方法

    今天小编就为大家分享一篇python numpy和list查询其中某个数的个数及定位方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • 带你详细了解Python GUI编程框架

    带你详细了解Python GUI编程框架

    今天小编就为大家分享一篇python 实现GUI(图形用户界面)编程详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-08-08
  • 基于Python实现生成随机手机号码

    基于Python实现生成随机手机号码

    这篇文章主要介绍了生成随机中国手机号码的Python代码实现,本文提供了基础版本和更精确的版本,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下
    2026-02-02
  • python 虚拟环境详解

    python 虚拟环境详解

    这篇文章主要为大家介绍了python 虚拟环境,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助<BR>
    2021-12-12
  • python创建进程fork用法

    python创建进程fork用法

    这篇文章主要介绍了python创建进程fork用法,实例分析了Python使用fork创建进程的使用方法,需要的朋友可以参考下
    2015-06-06
  • 现代Python编程的四个关键点你知道几个

    现代Python编程的四个关键点你知道几个

    这篇文章主要为大家详细介绍了Python编程的四个关键点,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2022-02-02

最新评论