使用Pytorch导出自定义ONNX算子的示例代码

 更新时间:2024年03月08日 12:00:30   作者:太阳花的小绿豆  
这篇文章主要介绍了使用Pytorch导出自定义ONNX算子的示例代码,下面给出个具体应用中的示例:需要导出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又无法正常导出该算子,故可通过如下自定义算子代码导出,需要的朋友可以参考下

在实际部署模型时有时可能会遇到想用的算子无法导出onnx,但实际部署的框架是支持该算子的。此时可以通过自定义onnx算子的方式导出onnx模型(注:自定义onnx算子导出onnx模型后是无法使用onnxruntime推理的)。下面给出个具体应用中的示例:需要导出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又无法正常导出该算子,故可通过如下自定义算子代码导出。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypes
class CustomAffineGrid(Function):
    @staticmethod
    def forward(ctx, theta: torch.Tensor, size: torch.Tensor):
        grid = F.affine_grid(theta=theta, size=size.cpu().tolist())
        return grid
    @staticmethod
    def symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor):
        return g.op("AffineGrid", theta, size)
class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
    def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor):
        grid = CustomAffineGrid.apply(theta, size)
        x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros")
        return x
def main():
    with torch.inference_mode():
        custum_model = MyModel()
        x = torch.randn(1, 3, 224, 224)
        theta = torch.randn(1, 2, 3)
        size = torch.as_tensor([1, 3, 512, 512])
        torch.onnx.export(model=custum_model,
                          args=(x, theta, size),
                          f="custom.onnx",
                          input_names=["input0_x", "input1_theta", "input2_size"],
                          output_names=["output"],
                          dynamic_axes={"input0_x": {2: "h0", 3: "w0"},
                                        "output": {2: "h1", 3: "w1"}},
                          opset_version=16,
                          operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)
if __name__ == '__main__':
    main()

在上面代码中,通过继承torch.autograd.Function父类的方式实现导出自定义算子,继承该父类后需要用户自己实现forward以及symbolic两个静态方法,其中forward方法是在pytorch正常推理时调用的函数,而symbolic方法是在导出onnx时调用的函数。对于forward方法需要按照正常的pytorch语法来实现,其中第一个参数必须是ctx但对于当前导出onnx场景可以不用管它,后面的参数是实际自己传入的参数。对于symbolic方法的第一个必须是g,后面的参数任为实际自己传入的参数,然后通过g.op方法指定具体导出自定义算子的名称,以及输入的参数(注:上面示例中传入的都是Tensor所以可以直接传入,对与非Tensor的参数可见下面一个示例)。最后在使用时直接调用自己实现类的apply方法即可。使用netron打开自己导出的onnx文件,可以看到如下所示网络结构。

有时按照使用的推理框架导出自定义算子时还需要设置一些参数(非Tensor)那么可以参考如下示例,例如要导出int型的参数k那么可以通过传入k_i来指定,要导出float型的参数scale那么可以通过传入scale_f来指定,要导出string型的参数clockwise那么可以通过传入clockwise_s来指定:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypes
class CustomRot90AndScale(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        x = torch.rot90(x, k=1, dims=(3, 2))  # clockwise 90
        x *= 1.2
        return x
    @staticmethod
    def symbolic(g: torch.Graph, x: torch.Tensor):
        return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes")
class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
    def forward(self, x: torch.Tensor):
        return CustomRot90AndScale.apply(x)
def main():
    with torch.inference_mode():
        custum_model = MyModel()
        x = torch.randn(1, 3, 224, 224)
        torch.onnx.export(model=custum_model,
                          args=(x,),
                          f="custom_rot90.onnx",
                          input_names=["input"],
                          output_names=["output"],
                          dynamic_axes={"input": {2: "h0", 3: "w0"},
                                        "output": {2: "w0", 3: "h0"}},
                          opset_version=16,
                          operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)
if __name__ == '__main__':
    main()

使用netron打开自己导出的onnx文件,可以看到如下所示信息。

到此这篇关于使用Pytorch导出自定义ONNX算子的文章就介绍到这了,更多相关使用Pytorch导出自定义ONNX算子内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 一文深入详解Python循环引用与垃圾回收

    一文深入详解Python循环引用与垃圾回收

    在 Python 中,内存管理是一个至关重要的主题,特别是在处理长时间运行的服务和大量数据时,内存泄漏和资源管理不当往往是导致服务性能下降或崩溃的根源之一,本文将深入探讨什么情况下会出现循环引用,GC(垃圾回收)是如何处理它的,需要的朋友可以参考下
    2026-05-05
  • Python基于Pymssql模块实现连接SQL Server数据库的方法详解

    Python基于Pymssql模块实现连接SQL Server数据库的方法详解

    这篇文章主要介绍了Python基于Pymssql模块实现连接SQL Server数据库的方法,较为详细的分析了pymssql模块的下载、安装及连接、操作SQL Server数据库的相关实现技巧,需要的朋友可以参考下
    2017-07-07
  • Python解析Excel图表Chart的信息实战指南

    Python解析Excel图表Chart的信息实战指南

    在数据分析与报表自动化场景中,Excel图表往往承载着关键业务信息,本文将基于OpenXML规范,通过将.xlsx文件视为ZIP压缩包,直接解析 xl/charts/chart*.xml,实现了对 Excel 图表元数据的精准提取,感兴趣的小伙伴可以了解下
    2026-01-01
  • 使用python采集Excel表中某一格数据

    使用python采集Excel表中某一格数据

    这篇文章主要介绍了使用python采集Excel表中某一格数据,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • Python 模块EasyGui详细介绍

    Python 模块EasyGui详细介绍

    这篇文章主要介绍了Python 模块EasyGui详细介绍的相关资料,需要的朋友可以参考下
    2017-02-02
  • 使用 Python 列出串口的实现方法

    使用 Python 列出串口的实现方法

    有时在编程时,我们需要获取有关系统中可用通信端口的信息, 我们将讨论如何使用 Python 来做到这一点,将讨论使用串口或 com 端口的通信, 我们将深入探索 Python 包,以帮助我们获得系统的可用通信端口,感兴趣的朋友一起看看吧
    2023-08-08
  • Python tkinter实现桌面软件流程详解

    Python tkinter实现桌面软件流程详解

    这篇文章主要介绍了Python tkinter做一个好用的桌面软件,100%你会爱上它,文中的示例代码讲解详细,快跟小编一起动手试一试吧
    2022-10-10
  • 从np.random.normal()到正态分布的拟合操作

    从np.random.normal()到正态分布的拟合操作

    这篇文章主要介绍了从np.random.normal()到正态分布的拟合操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-06-06
  • python服务器与android客户端socket通信实例

    python服务器与android客户端socket通信实例

    这篇文章主要介绍了python服务器与android客户端socket通信的实现方法,实例形式详细讲述了Python的服务器端实现原理与方法,以及对应的Android客户端实现方法,需要的朋友可以参考下
    2014-11-11
  • 一句话理解pyside6的信号和槽机制

    一句话理解pyside6的信号和槽机制

    本文介绍了PySide6信号与槽机制的核心概念和使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2026-01-01

最新评论