pytorch 计算Parameter和FLOP的操作

 更新时间:2021年03月04日 14:57:31   作者:落地生根1314  
这篇文章主要介绍了pytorch 计算Parameter和FLOP的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

深度学习中,模型训练完后,查看模型的参数量和浮点计算量,在此记录下:

1 THOP

在pytorch中有现成的包thop用于计算参数数量和FLOP,首先安装thop:

pip install thop

注意安装thop时可能出现如下错误:

解决方法:

pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git # 下载源码安装

使用方法如下:

from torchvision.models import resnet50 # 引入ResNet50模型
from thop import profile
model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224)) # profile(模型,输入数据)

对于自己构建的函数也一样,例如shuffleNetV2

  from thop import profile
  from utils.ShuffleNetV2 import shufflenetv2 # 导入shufflenet2 模块
  import torch 
  
  model_shuffle = shufflenetv2(width_mult=0.5)
  model = torch.nn.DataParallel(model_shuffle)  # 调用shufflenet2 模型,该模型为自己定义的
  flop, para = profile(model, input_size=(1, 3, 224, 224),) 
  print("%.2fM" % (flop/1e6), "%.2fM" % (para/1e6))

更多细节,可参考thop GitHub链接: https://github.com/Lyken17/pytorch-OpCounter

2 计算参数

pytorch本身带有计算参数的方法

  from thop import profile
  from utils.ShuffleNetV2 import shufflenetv2 # 导入shufflenet2 模块
  import torch 
  
  model_shuffle = shufflenetv2(width_mult=0.5)
  model = torch.nn.DataParallel(model_shuffle)
  total = sum([param.nelement() for param in model.parameters()])
  print("Number of parameter: %.2fM" % (total / 1e6))

补充:pytorch: 计算网络模型的计算量(FLOPs)和参数量(Params)

计算量:

FLOPs,FLOP时指浮点运算次数,s是指秒,即每秒浮点运算次数的意思,考量一个网络模型的计算量的标准。

参数量:

Params,是指网络模型中需要训练的参数总数。

第一步:安装模块(thop)

pip install thop

第二步:计算

import torch
from thop import profile
net = Model() # 定义好的网络模型
input = torch.randn(1, 3, 112, 112)
flops, params = profile(net, (inputs,))
print('flops: ', flops, 'params: ', params)

注意:

输入input的第一维度是批量(batch size),批量的大小不回影响参数量, 计算量是batch_size=1的倍数

profile(net, (inputs,))的 (inputs,)中必须加上逗号,否者会报错

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

相关文章

  • 基于Python实现视频的人脸融合功能

    基于Python实现视频的人脸融合功能

    这篇文章主要介绍了用Python快速实现视频的人脸融合功能,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-06-06
  • Django中Middleware中的函数详解

    Django中Middleware中的函数详解

    这篇文章主要介绍了Django中Middleware中的函数详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • 用Python一键搭建Http服务器的方法

    用Python一键搭建Http服务器的方法

    今天小编就为大家分享一篇用Python一键搭建Http服务器的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • 使用Tensorflow将自己的数据分割成batch训练实例

    使用Tensorflow将自己的数据分割成batch训练实例

    今天小编就为大家分享一篇使用Tensorflow将自己的数据分割成batch训练实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • 如何使用PyCharm将代码上传到GitHub上(图文详解)

    如何使用PyCharm将代码上传到GitHub上(图文详解)

    这篇文章主要介绍了如何使用PyCharm将代码上传到GitHub上(图文详解),文中通过图文介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-04-04
  • 用代码帮你了解Python基础(3)

    用代码帮你了解Python基础(3)

    这篇文章主要用代码帮你了解Python基础,使用循环,字典和集合的示例代码,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-01-01
  • Python打印scrapy蜘蛛抓取树结构的方法

    Python打印scrapy蜘蛛抓取树结构的方法

    这篇文章主要介绍了Python打印scrapy蜘蛛抓取树结构的方法,实例分析了打印scrapy蜘蛛抓取树结构的技巧,非常具有实用价值,需要的朋友可以参考下
    2015-04-04
  • pytorch DataLoader的num_workers参数与设置大小详解

    pytorch DataLoader的num_workers参数与设置大小详解

    这篇文章主要介绍了pytorch DataLoader的num_workers参数与设置大小详解,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-05-05
  • Python循环语句之while循环和for循环详解

    Python循环语句之while循环和for循环详解

    在Python中,循环语句用于重复执行一段代码,直到满足某个条件为止,在Python中,有两种主要的循环语句:for循环和while循环,本文就来给大家介绍一下这两个循环的用法,需要的朋友可以参考下
    2023-08-08
  • python 绘制场景热力图的示例

    python 绘制场景热力图的示例

    这篇文章主要介绍了python 绘制场景热力图的示例,帮助大家更好的利用python绘制图像,感兴趣的朋友可以了解下
    2020-09-09

最新评论