pytorch中的reshape()、view()、nn.flatten()和flatten()使用

 更新时间:2023年08月02日 10:40:22   作者:梦在黎明破晓时啊  
这篇文章主要介绍了pytorch中的reshape()、view()、nn.flatten()和flatten()使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

在使用pytorch定义神经网络结构时,经常会看到类似如下的.view() / flatten()用法,这里对其用法做出讲解与演示。

torch.reshape用法

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用,其作用是在不改变tensor元素数目的情况下改变tensor的shape。

torch.reshape() 需要两个参数,一个是待被改变的张量tensor,一个是想要改变的形状。

torch.reshape(input, shape) → Tensor

  • input(Tensor)-要重塑的张量
  • shape(python的元组:ints)-新形状`

案例1.

输入:

import torch
a = torch.tensor([[0,1],[2,3]])
x = torch.reshape(a,(-1,))
print (x)
b = torch.arange(4.)
Y = torch.reshape(a,(2,2))
print(Y)

结果:

tensor([0, 1, 2, 3])
tensor([[0, 1],
[2, 3]])

torch.view用法

view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。

view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。

view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。

一、普通用法(手动调整)

view()相当于reshape、resize,重新调整Tensor的形状。

案例2.

输入

import torch
a1 = torch.arange(0,16)
print(a1)

输出

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

输入

a2 = a1.view(8, 2)
a3 = a1.view(2, 8)
a4 = a1.view(4, 4)
print(a2)
print(a3)
print(a4)

输出

tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])

二、特殊用法:参数-1(自动调整size)

view中一个参数定为-1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。

输入

import torch
a1 = torch.arange(0,16)
print(a1)

输出

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

输入

a2 = a1.view(-1, 16)
a3 = a1.view(-1, 8)
a4 = a1.view(-1, 4)
a5 = a1.view(-1, 2)
a6 = a1.view(4*4, -1)
a7 = a1.view(1*4, -1)
a8 = a1.view(2*4, -1)
print(a2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a7)
print(a8)

输出

tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])

torch.nn.Flatten(start_dim=1,end_dim=-1)

start_dim与end_dim分别表示开始的维度和终止的维度,默认值为1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)。

因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。

使用nn.Flatten(),使用默认参数

官方给出的示例:

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
#torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
#torch.Size([160, 5])

#开头的代码是注释

整段代码的意思是:给定一个维度为(32,1,5,5)的随机数据。

1.先使用一次nn.Flatten(),使用默认参数:

m = nn.Flatten()

也就是说从第一维度展平到最后一个维度,数据的维度是从0开始的,第一维度实际上是数据的第二位置代表的维度,也就是样例中的1。

因此进行展平后的结果也就是[32,155]→[32,25]

2.接着再使用一次指定参数的nn.Flatten(),即

m = nn.Flatten(0,2)

也就是说从第0维度展平到第2维度,0~2,对应的也就是前三个维度。

因此结果就是[3215,5]→[160,25]

torch.flatten

torch.flatten()函数经常用于写分类神经网络的时候,经过最后一个卷积层之后,一般会再接一个自适应的池化层,输出一个BCHW的向量。

这时候就需要用到torch.flatten()函数将这个向量拉平成一个Bx的向量(其中,x = CHW),然后送入到FC层中。

在这里插入图片描述

语句结构

 torch.flatten(input, start_dim=0, end_dim=-1)

input: 一个 tensor,即要被“摊平”的 tensor。

  • start_dim: “摊平”的起始维度。
  • end_dim: “摊平”的结束维度。

作用与 torch.nn.flatten 类似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是类,其默认开始维度为第 0 维。

例1:

import torch
data_pool = torch.randn(2,2,3,3) # 模拟经过最后一个池化层或自适应池化层之后的输出,Batchsize*c*h*w
print(data_pool)
y=torch.flatten(data_pool,1)
print(y)

输出结果:

在这里插入图片描述

结果是一个B*x的向量。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 3个用于数据科学的顶级Python库

    3个用于数据科学的顶级Python库

    今天小编就为大家分享一篇关于3个用于数据科学的顶级Python库,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2018-09-09
  • python mysqldb连接数据库

    python mysqldb连接数据库

    今天无事想弄下python做个gui开发,最近发布的是python 3k,用到了数据库,通过搜索发现有一个mysqldb这样的控件,可以使用,就去官方看了下结果,没有2.6以上的版本
    2009-03-03
  • Selenium定位浏览器弹窗方法实例总结

    Selenium定位浏览器弹窗方法实例总结

    弹出框是自动化测试中一种常见的元素,这种元素通常是客户端自带的,下面这篇文章主要给大家介绍了关于Selenium定位浏览器弹窗方法的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-06-06
  • Python创建增量目录的代码实例

    Python创建增量目录的代码实例

    这篇文章主要给大家介绍了关于Python创建增量目录的相关资料,文中通过实例代码介绍的非常详细,对大家学习或者使用python具有一定的参考学习价值,需要的朋友可以参考下
    2022-11-11
  • Python格式化处理JSON数据的完整指南

    Python格式化处理JSON数据的完整指南

    在Python中,我们经常需要处理JSON数据,而格式化JSON数据是开发过程中的常见需求,本文将详细介绍如何在Python中对JSON数据进行格式化处理,感兴趣的小伙伴可以跟随小编一起学习一下
    2026-04-04
  • NetworkX之Prim算法(实例讲解)

    NetworkX之Prim算法(实例讲解)

    下面小编就为大家分享一篇NetworkX之Prim算法实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2017-12-12
  • Python函数注释备注的写法和使用方法

    Python函数注释备注的写法和使用方法

    在Python中,为函数添加备注主要通过文档字符串(docstring) 实现,这是Python官方推荐的最佳实践,而非单纯使用#单行注释,下面我会详细介绍不同风格的函数备注写法和使用方法,需要的朋友可以参考下
    2026-01-01
  • Anaconda入门使用总结

    Anaconda入门使用总结

    个人尝试了很多类似的发行版,最终选择了Anaconda,因为其强大而方便的包管理与环境管理的功能。该文主要介绍下Anaconda,对Anaconda的理解,并简要总结下相关的操作
    2018-04-04
  • 解决Django migrate不能发现app.models的表问题

    解决Django migrate不能发现app.models的表问题

    今天小编就为大家分享一篇解决Django migrate不能发现app.models的表问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • python绘制汉诺塔

    python绘制汉诺塔

    这篇文章主要为大家详细介绍了python绘制汉诺塔,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-03-03

最新评论