关于torch.flatten()函数及x=x.view()函数的理解

 更新时间:2025年04月10日 17:31:13   作者:浩瀚之水_csdn  
这篇文章主要介绍了关于torch.flatten()函数及x=x.view()函数的理解,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

一、x = x.view()

x = x.view(x.size(0), -1)

在PyTorch中,x.view(x.size(0), -1)是一种常用的操作,用于改变张量(Tensor)的形状而不改变其数据。

这里的x是一个多维张量,而.view()函数是用来重新塑形这个张量的,同时保持其元素的总数不变。

具体来说,x.view(x.size(0), -1)的含义是:

  • x.size(0):这部分获取了张量x的第一个维度的大小(即,如果x是一个形状为(a, b, c)的张量,那么x.size(0)就等于a)。在大多数情况下,这代表了批处理中的样本数或者是序列的长度,取决于上下文。
  • -1:在.view()函数中,-1是一个特殊的值,表示该维度的大小会自动计算,以便保持总元素数不变。换句话说,PyTorch会根据其他维度的大小和总元素数来推断出-1应该代表的具体数值。

因此,x.view(x.size(0), -1)的作用是将张量x重新塑形为一个二维张量,其中第一维的大小保持不变(即原始张量的第一个维度的大小),而第二维的大小则自动调整,以包含所有剩余的元素。

这种操作在需要将多维数据“展平”为二维数据以进行某些操作(如全连接层)时非常有用。

例如,如果x是一个形状为(64, 3, 28, 28)的张量(通常表示一个包含64个图像,每个图像有3个颜色通道,每个通道的大小为28x28像素的数据集),那么x.view(x.size(0), -1)将会把x重新塑形为一个形状为(64, 3*28*28)的张量,其中每个样本都被展平成了一个长向量。

二、torch.flatten()函数

x = torch.flatten(x, start_dim=0, end_dim=2)

x = torch.flatten(x, 0)

当你使用 x = torch.flatten(x, 0) 时,这里的 0start_dim 参数的值,而 end_dim 参数仍然默认为 -1。这意呀着展平操作将从张量 x 的第一个维度(索引为0的维度)开始,并且一直进行到张量的最后一个维度。

然而,由于 start_dim 被设置为0,并且 end_dim 默认为 -1,实际上这会将整个张量 x 完全展平为一个一维张量。换句话说,无论原始张量 x 的形状如何,调用 torch.flatten(x, 0) 后,x 将变成一个一维张量,其长度等于原始张量中所有元素的总数。

例如,如果原始张量 x 的形状是 (a, b, c, d),那么调用 x = torch.flatten(x, 0) 后,x 的新形状将是 (a*b*c*d,),即一个包含 a*b*c*d 个元素的一维张量。

这种完全展平的操作在需要将多维数据转换为适合某些特定操作(如完全连接层的前馈传播)的一维形式时非常有用。然而,它也意味着你丢失了原始数据的形状信息,除非你在其他地方记录了这些信息或者你的操作不需要保留这些形状信息。

x = torch.flatten(x, 1)

在PyTorch中,torch.flatten(x, start_dim=0, end_dim=-1)函数用于将张量x在指定的维度范围内展平(或扁平化),而不改变其数据。这里的start_dim是开始展平的维度(包含该维度),end_dim是结束展平的维度(不包含该维度),默认情况下end_dim为-1,即最后一个维度。

当你使用x = torch.flatten(x, 1)时,你告诉PyTorch从第二个维度(索引为1,因为索引是从0开始的)开始,一直到最后一个维度,将所有的这些维度都展平成一个维度。这意味着,如果x是一个多维张量,那么除了第一个维度之外的所有维度都将被合并成一个维度。

例如,如果x的形状是(64, 3, 28, 28)(代表64个图像,每个图像有3个颜色通道,每个通道的大小为28x28像素),那么x = torch.flatten(x, 1)将会把x展平成一个形状为(64, 3*28*28)的张量。这里,第一个维度(样本数64)保持不变,而剩下的三个维度(3, 28, 28)被合并成了一个维度。

这种操作在处理图像数据时特别有用,尤其是在需要将图像数据传递给全连接层之前,因为全连接层通常期望输入是二维的(尽管在实践中,通常会先通过一个或多个卷积层来处理图像数据)。通过展平操作,你可以将多维的图像数据转换成二维的形式,以便进行后续处理。

x = torch.flatten(x, 2)

在PyTorch中,torch.flatten(input, start_dim=0, end_dim=-1) 函数用于将多维张量(tensor)展平(flatten)为一维张量,但你可以通过指定start_dimend_dim参数来控制从哪一维度开始展平,以及在哪一维度结束(不包括该维度)。这意味着你可以保留张量的某些维度不变,而将其他维度展平。

对于你的代码 x = torch.flatten(x, 2),这里:

  • x 是你想要展平的原始张量。
  • 2start_dim参数的值,而end_dim参数默认为-1,表示展平操作会一直进行到张量的最后一个维度。

因此,torch.flatten(x, 2) 的意思是从张量x的第3维(因为索引从0开始)开始,将之后的所有维度都展平成一个维度。如果x的形状是例如 (a, b, c, d, e),那么torch.flatten(x, 2)之后,x的形状将变为 (a, b, c*d*e)。这里,ab维度保持不变,而cde三个维度被合并成了一个新的维度。

这种操作在处理多维数据时非常有用,特别是当你需要将一部分数据的维度保持不变,而将其他部分数据“展平”以便于后续处理(如全连接层处理)时。

三、示例

import torch

A = torch.tensor([[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[13,14,15,16],[17,18,19,20],[21,22,23,24]]])
# print(A.size)
print(A.shape)

B = torch.flatten(A,1)
print(B.shape)
# print(B)

C = torch.flatten(A,0,1)
print(C.shape)
# print(C)

D = torch.flatten(A,2)
print(D.shape)
# print(D)

E = torch.flatten(A,0)
print(E.shape)
# print(E)

F = A.view(A.size(0), -1)
print(F.shape)
# print(F)

G = A.view(A.size(0), -1, 1)
print(G.shape)

输出:

torch.Size([2, 3, 4])
torch.Size([2, 12])
torch.Size([6, 4])
torch.Size([2, 3, 4])
torch.Size([24])
torch.Size([2, 12])
torch.Size([2, 12, 1])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 使用Python对Excel进行读写操作

    使用Python对Excel进行读写操作

    学习Python的过程中,我们会遇到Excel的读写问题。这时,我们可以使用xlwt模块将数据写入Excel表格中,使用xlrd模块从Excel中读取数据。下面我们介绍如何实现使用Python对Excel进行读写操作。
    2017-03-03
  • 关于pytorch中网络loss传播和参数更新的理解

    关于pytorch中网络loss传播和参数更新的理解

    今天小编就为大家分享一篇关于pytorch中网络loss传播和参数更新的理解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • 用python实现刷点击率的示例代码

    用python实现刷点击率的示例代码

    今天小编就为大家分享一篇用python实现刷点击率的示例代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-02-02
  • Python游戏开发之Pygame使用的最全教程分享

    Python游戏开发之Pygame使用的最全教程分享

    Pygame库是Python中一个专为游戏开发设计的库,它提供了大量的功能来帮助开发者创建各种2D游戏,本文就来和大家分享一下Pygame的具体使用,希望对大家有所帮助
    2023-05-05
  • 基于Python实现一键获取电脑浏览器的账号密码

    基于Python实现一键获取电脑浏览器的账号密码

    发现很多人在学校图书馆喜欢用电脑占座,而且出去的时候经常不锁屏,为了让大家养成良好的习惯,本文将分享一个小程序,可以快速获取你存储在电脑浏览器中的所有账号和密码,感兴趣的可以了解一下
    2022-05-05
  • Pandas0.25来了千万别错过这10大好用的新功能

    Pandas0.25来了千万别错过这10大好用的新功能

    这篇文章主要介绍了Pandas0.25来了千万别错过这10大好用的新功能,都有哪些新功能,文中给大家详细介绍,需要的朋友可以参考下
    2019-08-08
  • Pandas排序和分组排名(sort和rank)的实现

    Pandas排序和分组排名(sort和rank)的实现

    Pandas是Python中广泛使用的数据处理库,提供了丰富的功能来处理和分析数据,本文主要介绍了Pandas排序和分组排名(sort和rank)的实现,具有一定的参考价值,感兴趣的可以了解一下
    2024-07-07
  • 如何使用Python异步之上下文管理器

    如何使用Python异步之上下文管理器

    这篇文章主要为大家介绍了如何使用Python异步之上下文管理器详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-03-03
  • Python实现的对一个数进行因式分解操作示例

    Python实现的对一个数进行因式分解操作示例

    这篇文章主要介绍了Python实现的对一个数进行因式分解操作,结合实例形式分析了Python因式分解数值运算相关操作技巧,需要的朋友可以参考下
    2019-06-06
  • python生成png的方法

    python生成png的方法

    本文主要介绍了python生成png的方法,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-10-10

最新评论