Pytorch通过保存为ONNX模型转TensorRT5的实现

 更新时间:2020年05月25日 11:23:24   作者:小关学长  
这篇文章主要介绍了Pytorch通过保存为ONNX模型转TensorRT5的实现,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

1 Pytorch以ONNX方式保存模型

 def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

 def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

 def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python with语句上下文管理器两种实现方法分析

    Python with语句上下文管理器两种实现方法分析

    这篇文章主要介绍了Python with语句上下文管理器两种实现方法,结合实例形式较为详细的分析了Python上下文管理器的相关概念、功能、使用方法及相关操作注意事项,需要的朋友可以参考下
    2018-02-02
  • Selenium多窗口切换解决方案

    Selenium多窗口切换解决方案

    本文主要介绍了Selenium多窗口切换解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-07-07
  • 解读keras中的正则化(regularization)问题

    解读keras中的正则化(regularization)问题

    这篇文章主要介绍了解读keras中的正则化(regularization)问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-12-12
  • 使用Python实现摇号系统的详细步骤

    使用Python实现摇号系统的详细步骤

    这篇文章主要介绍了如何使用Python构建一个简单的摇号系统,包括需求分析、技术栈、实现步骤和完整代码示例,该系统能够从用户输入的参与者名单中随机抽取指定数量的中奖者,并将结果展示给用户以及记录到日志文件中,需要的朋友可以参考下
    2024-11-11
  • 解决python3爬虫无法显示中文的问题

    解决python3爬虫无法显示中文的问题

    下面小编就为大家分享一篇解决python3爬虫无法显示中文的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • python爬取之json、pickle与shelve库的深入讲解

    python爬取之json、pickle与shelve库的深入讲解

    这篇文章主要给大家介绍了关于python爬取之json、pickle与shelve库的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-03-03
  • PyCharm 2019.3发布增加了新功能一览

    PyCharm 2019.3发布增加了新功能一览

    这篇文章主要介绍了PyCharm 2019.3发布,增加了新功能一览,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-12-12
  • keras model.fit 解决validation_spilt=num 的问题

    keras model.fit 解决validation_spilt=num 的问题

    这篇文章主要介绍了keras model.fit 解决validation_spilt=num 的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • Python figure参数及subplot子图绘制代码

    Python figure参数及subplot子图绘制代码

    这篇文章主要介绍了Python figure参数及subplot子图绘制代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-04-04
  • python matplotlib绘画十一种常见数据分析图

    python matplotlib绘画十一种常见数据分析图

    这篇文章主要介绍了python matplotlib绘画十一种常见数据分析图,文章主要绘制折线图、散点图、直方图、饼图等需要的小伙伴可以参考一下文章具体内容
    2022-06-06

最新评论