如何计算 tensorflow 和 pytorch 模型的浮点运算数

 更新时间:2022年11月26日 16:57:10   作者:浩哥依然  
FLOPs 是 floating point operations 的缩写,指浮点运算数,可以用来衡量模型/算法的计算复杂度。本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对应模型的 FLOPs,需要的朋友可以参考下

本文主要讨论如何计算 tensorflow 和 pytorch 模型的 FLOPs。如有表述不当之处欢迎批评指正。欢迎任何形式的转载,但请务必注明出处。

1. 引言

FLOPs 是 floating point operations 的缩写,指浮点运算数,可以用来衡量模型/算法的计算复杂度。本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对应模型的 FLOPs。

2. 模型结构

为了说明方便,先搭建一个简单的神经网络模型,其模型结构以及主要参数如表1 所示。

表 1 模型结构及主要参数

LayerschannelsKernelsStridesUnitsActivation
Conv2D32(4,4)(1,2)\relu
GRU\\\96\
Dense\\\256sigmoid

用 tensorflow(实际使用 tensorflow 中的 keras 模块)实现该模型的代码为:

from tensorflow.keras.layers import *
from tensorflow.keras.models import load_model, Model

def test_model_tf(Input_shape):
    # shape: [B, C, T, F]
    main_input = Input(batch_shape=Input_shape, name='main_inputs')
    
    conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation='relu', data_format='channels_first', name='conv')(main_input)
    
    # shape: [B, T, FC]
    gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv)
    gru = GRU(units=96, reset_after=True, return_sequences=True, name='gru')(gru)
    
    output = Dense(256, activation='sigmoid', name='output')(gru)
    
    model = Model(inputs=[main_input], outputs=[output])
    
    return model

用 pytorch 实现该模型的代码为:

import torch
import torch.nn as nn

class test_model_torch(nn.Module):
    def __init__(self):
        super(test_model_torch, self).__init__()

        self.conv2d = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4,4), stride=(1,2))
        self.relu = nn.ReLU()

        self.gru = nn.GRU(input_size=4064, hidden_size=96)

        self.fc = nn.Linear(96, 256)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        # shape: [B, C, T, F]
        out = self.conv2d(inputs)
        out = self.relu(out)
        
        # shape: [B, T, FC]
        batch, channel, frame, freq = out.size()
        out = torch.reshape(out, (batch, frame, freq*channel))
        out, _ = self.gru(out)
        
        out = self.fc(out)
        out = self.sigmoid(out)

        return out

3. 计算模型的 FLOPs

本节讨论的版本具体为:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102。

3.1. tensorflow 1.12.0

在 tensorflow 1.12.0 环境中,可以使用以下代码计算模型的 FLOPs:

import tensorflow as tf
import tensorflow.keras.backend as K

def get_flops(model):
    run_meta = tf.RunMetadata()
    opts = tf.profiler.ProfileOptionBuilder.float_operation()

    flops = tf.profiler.profile(graph=K.get_session().graph,
                                run_meta=run_meta, cmd='op', options=opts)
 
    return flops.total_float_ops

if __name__ == "__main__":
    x = K.random_normal(shape=(1, 1, 100, 256))
    model = test_model_tf(x.shape)
    print('FLOPs of tensorflow 1.12.0:', get_flops(model))

3.2. tensorflow 2.3.1

在 tensorflow 2.3.1 环境中,可以使用以下代码计算模型的 FLOPs :

import tensorflow.compat.v1 as tf
import tensorflow.compat.v1.keras.backend as K
tf.disable_eager_execution()

def get_flops(model):
    run_meta = tf.RunMetadata()
    opts = tf.profiler.ProfileOptionBuilder.float_operation()

    flops = tf.profiler.profile(graph=K.get_session().graph,
                                run_meta=run_meta, cmd='op', options=opts)
 
    return flops.total_float_ops

if __name__ == "__main__":
    x = K.random_normal(shape=(1, 1, 100, 256))
    model = test_model_tf(x.shape)
    print('FLOPs of tensorflow 2.3.1:', get_flops(model))

3.3. pytorch 1.10.1+cu102

在 pytorch 1.10.1+cu102 环境中,可以使用以下代码计算模型的 FLOPs(需要安装 thop):

import thop

x = torch.randn(1, 1, 100, 256)
model = test_model_torch()
flops, _ = thop.profile(model, inputs=(x,))
print('FLOPs of pytorch 1.10.1:', flops * 2)

需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代码有乘 2 2 2 操作。

3.4. 结果对比

三者计算出的 FLOPs 分别为:
tensorflow 1.12.0:

tensorflow 2.3.1:

pytorch 1.10.1:


可以看到 tensorflow 1.12.0 和 tensorflow 2.3.1 的结果基本在同一个量级,而与 pytorch 1.10.1 计算出来的相差甚远。但如果将上述模型结构改为只包含第一层 Conv2D,三者计算出来的 FLOPs 却又是一致的。所以推断差异主要来自于 GRU 的 FLOPs。如读者知道其中详情,还请不吝赐教。

4. 总结

本文给出了在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算模型 FLOPs 的方法,但从本文所使用的测试模型来看, tensorflow 与 pytorch 统计出的结果相差甚远。当然,也可以根据网络层的类型及其对应的参数,推导计算出每个网络层所需的 FLOPs。

到此这篇关于计算 tensorflow 和 pytorch 模型的浮点运算数的文章就介绍到这了,更多相关tensorflow 和 pytorch浮点运算数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python求解三角形第三边长实例

    python求解三角形第三边长实例

    这篇文章主要介绍了python求解三角形第三边长实例,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-05-05
  • Python3使用xlrd、xlwt处理Excel方法数据

    Python3使用xlrd、xlwt处理Excel方法数据

    这篇文章主要介绍了Python3使用xlrd、xlwt处理Excel方法数据,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-02-02
  • python自动打开浏览器下载zip并提取内容写入excel

    python自动打开浏览器下载zip并提取内容写入excel

    这篇文章主要给大家介绍了关于python自动打开浏览器下载zip并提取内容写入excel的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01
  • Python实现http服务器(http.server模块传参 接收参数)实例

    Python实现http服务器(http.server模块传参 接收参数)实例

    这篇文章主要为大家介绍了Python实现http服务器(http.server模块传参 接收参数)实例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-11-11
  • Python 实现PS滤镜的旋涡特效

    Python 实现PS滤镜的旋涡特效

    这篇文章主要介绍了Python 实现 PS 滤镜的旋涡特效,帮助大家更好的利用python处理图片,感兴趣的朋友可以了解下
    2020-12-12
  • 使用Matplotlib 绘制精美的数学图形例子

    使用Matplotlib 绘制精美的数学图形例子

    今天小编就为大家分享一篇使用Matplotlib 绘制精美的数学图形例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • 初探TensorFLow从文件读取图片的四种方式

    初探TensorFLow从文件读取图片的四种方式

    本篇文章主要介绍了初探TensorFLow从文件读取图片的四种方式,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-02-02
  • 对python sklearn one-hot编码详解

    对python sklearn one-hot编码详解

    今天小编就为大家分享一篇对python sklearn one-hot编码详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • python内置函数之eval函数详解

    python内置函数之eval函数详解

    这篇文章主要为大家介绍了python内置函数之eval函数,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2022-01-01
  • Python单链表原理与实现方法详解

    Python单链表原理与实现方法详解

    这篇文章主要介绍了Python单链表原理与实现方法,结合实例形式详细分析了Python单链表的具体概念、原理、实现方法与操作注意事项,需要的朋友可以参考下
    2020-02-02

最新评论