pytorch中nn.Flatten()函数详解及示例

 更新时间:2023年01月06日 16:02:12   作者:浅挚灬半离兮  
nn.Flatten是一个类,而torch.flatten()则是一个函数,下面这篇文章主要给大家介绍了关于pytorch中nn.Flatten()函数详解及示例的相关资料,需要的朋友可以参考下

torch.nn.Flatten(start_dim=1, end_dim=- 1)

作用:将连续的维度范围展平为张量。 经常在nn.Sequential()中出现,一般写在某个神经网络模型之后,用于对神经网络模型的输出进行处理,得到tensor类型的数据。

有俩个参数,start_dim和end_dim,分别表示开始的维度和终止的维度,默认值分别是1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)

同理,如果我这么写:

self.flat = nn.Flatten(start_dim=2, end_dim=3)

那么意思就是从第二维度开始,到第三维度全部给展平,也就是将2、3两个维度展平。

官网给出的示例:

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
#torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
#torch.Size([160, 5])

#开头的代码是注释

整段代码的意思是:给定一个维度为(32,1,5,5)的随机数据。

1.先使用一次nn.Flatten(),使用默认参数:

m = nn.Flatten()

也就是说从第一维度展平到最后一个维度,数据的维度是从0开始的,第一维度实际上是数据的第二个位置代表的维度,也就是样例中的1。

因此进行展平后的结果也就是[32,1×5×5]➡[32,25]

2.接着再使用一次指定参数的nn.Flatten(),即

m = nn.Flatten(0, 2)

也就是说从第0维度展平到第2维度,0~2,对应的也就是前三个维度。

因此结果就是[32×1×5,5]➡[160,5]

因此进行展平后的结果也就是[32,1*5*5]➡[32,25]

示例1

卷积公式

import torch
import torch.nn as nn
input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),  # 通过卷积,得到torch.size([32, 32, 3, 3]
    nn.Flatten())

output = m(input)
print(output.size())

>> torch.Size([32, 288])

示例2

import torch
import torch.nn as nn
input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),  # 通过卷积,得到torch.size([32, 32, 3, 3]
    nn.Flatten(start_dim=0))

output = m(input)
print(output.size())

>>torch.Size([9216])

总结

到此这篇关于pytorch中nn.Flatten()函数详解的文章就介绍到这了,更多相关pytorch nn.Flatten()函数详解内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python读取MRI并显示为灰度图像实例代码

    Python读取MRI并显示为灰度图像实例代码

    这篇文章主要介绍了Python读取MRI并显示为灰度图像实例代码,具有一定借鉴价值,需要的朋友可以参考下
    2018-01-01
  • 完美解决torch.cuda.is_available()一直返回False的玄学方法

    完美解决torch.cuda.is_available()一直返回False的玄学方法

    这篇文章主要介绍了完美解决torch.cuda.is_available()一直返回False的玄学方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-02-02
  • python报错TypeError: Input z must be 2D, not 3D的解决方法

    python报错TypeError: Input z must be 

    大家好,本篇文章主要讲的是python报错TypeError: Input z must be 2D, not 3D的解决方法,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下
    2021-12-12
  • 使用Python制作读单词视频的实现代码

    使用Python制作读单词视频的实现代码

    我们经常在B站或其他视频网站上看到那种逐条读单词的视频,但他们的视频多多少少和我们的预期都不太一致,然而,网上很难找到和自己需求符合的视频,所以本文给大家介绍了使用Python制作读单词视频的实现,需要的朋友可以参考下
    2024-04-04
  •  python中的元类metaclass详情

     python中的元类metaclass详情

    这篇文章主要介绍了python中的metaclass详情,在python中的metaclass就是帮助developer实现元编程,更多详细内容需要的小伙伴可以参考一下
    2022-05-05
  • Python pandas 的索引方式 data.loc[],data[][]示例详解

    Python pandas 的索引方式 data.loc[],data[][]示例详解

    这篇文章主要介绍了Python pandas 的索引方式 data.loc[], data[][]的相关资料,其中data.loc[index,column]使用.loc[ ]第一个参数是行索引,第二个参数是列索引,本文结合实例代码讲解的非常详细,需要的朋友可以参考下
    2023-02-02
  • 快速解释如何使用pandas的inplace参数的使用

    快速解释如何使用pandas的inplace参数的使用

    这篇文章主要介绍了快速解释如何使用pandas的inplace参数的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-07-07
  • Python面向对象原理与基础语法详解

    Python面向对象原理与基础语法详解

    这篇文章主要介绍了Pyhton面向对象原理与基础语法,结合实例形式分析了Python面向对象程序设计中的基本原理、概念、语法与相关使用技巧,需要的朋友可以参考下
    2020-01-01
  • python ChainMap管理用法实例讲解

    python ChainMap管理用法实例讲解

    在本篇文章里小编给大家整理一篇关于python ChainMap的管理用法及相关实例,有需要的朋友们可以学参考下。
    2021-08-08
  • Pytorch torch.repeat_interleave()用法示例详解

    Pytorch torch.repeat_interleave()用法示例详解

    torch.repeat_interleave() 是 PyTorch 中的一个函数,用于按指定的方式重复张量中的元素,这篇文章主要介绍了Pytorch torch.repeat_interleave()用法示例详解,需要的朋友可以参考下
    2024-01-01

最新评论