关于torch.flatten()函数及x=x.view()函数的理解

 更新时间:2025年04月10日 17:31:13   作者:浩瀚之水_csdn  
这篇文章主要介绍了关于torch.flatten()函数及x=x.view()函数的理解,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

一、x = x.view()

x = x.view(x.size(0), -1)

在PyTorch中,x.view(x.size(0), -1)是一种常用的操作,用于改变张量(Tensor)的形状而不改变其数据。

这里的x是一个多维张量,而.view()函数是用来重新塑形这个张量的,同时保持其元素的总数不变。

具体来说,x.view(x.size(0), -1)的含义是:

  • x.size(0):这部分获取了张量x的第一个维度的大小(即,如果x是一个形状为(a, b, c)的张量,那么x.size(0)就等于a)。在大多数情况下,这代表了批处理中的样本数或者是序列的长度,取决于上下文。
  • -1:在.view()函数中,-1是一个特殊的值,表示该维度的大小会自动计算,以便保持总元素数不变。换句话说,PyTorch会根据其他维度的大小和总元素数来推断出-1应该代表的具体数值。

因此,x.view(x.size(0), -1)的作用是将张量x重新塑形为一个二维张量,其中第一维的大小保持不变(即原始张量的第一个维度的大小),而第二维的大小则自动调整,以包含所有剩余的元素。

这种操作在需要将多维数据“展平”为二维数据以进行某些操作(如全连接层)时非常有用。

例如,如果x是一个形状为(64, 3, 28, 28)的张量(通常表示一个包含64个图像,每个图像有3个颜色通道,每个通道的大小为28x28像素的数据集),那么x.view(x.size(0), -1)将会把x重新塑形为一个形状为(64, 3*28*28)的张量,其中每个样本都被展平成了一个长向量。

二、torch.flatten()函数

x = torch.flatten(x, start_dim=0, end_dim=2)

x = torch.flatten(x, 0)

当你使用 x = torch.flatten(x, 0) 时,这里的 0start_dim 参数的值,而 end_dim 参数仍然默认为 -1。这意呀着展平操作将从张量 x 的第一个维度(索引为0的维度)开始,并且一直进行到张量的最后一个维度。

然而,由于 start_dim 被设置为0,并且 end_dim 默认为 -1,实际上这会将整个张量 x 完全展平为一个一维张量。换句话说,无论原始张量 x 的形状如何,调用 torch.flatten(x, 0) 后,x 将变成一个一维张量,其长度等于原始张量中所有元素的总数。

例如,如果原始张量 x 的形状是 (a, b, c, d),那么调用 x = torch.flatten(x, 0) 后,x 的新形状将是 (a*b*c*d,),即一个包含 a*b*c*d 个元素的一维张量。

这种完全展平的操作在需要将多维数据转换为适合某些特定操作(如完全连接层的前馈传播)的一维形式时非常有用。然而,它也意味着你丢失了原始数据的形状信息,除非你在其他地方记录了这些信息或者你的操作不需要保留这些形状信息。

x = torch.flatten(x, 1)

在PyTorch中,torch.flatten(x, start_dim=0, end_dim=-1)函数用于将张量x在指定的维度范围内展平(或扁平化),而不改变其数据。这里的start_dim是开始展平的维度(包含该维度),end_dim是结束展平的维度(不包含该维度),默认情况下end_dim为-1,即最后一个维度。

当你使用x = torch.flatten(x, 1)时,你告诉PyTorch从第二个维度(索引为1,因为索引是从0开始的)开始,一直到最后一个维度,将所有的这些维度都展平成一个维度。这意味着,如果x是一个多维张量,那么除了第一个维度之外的所有维度都将被合并成一个维度。

例如,如果x的形状是(64, 3, 28, 28)(代表64个图像,每个图像有3个颜色通道,每个通道的大小为28x28像素),那么x = torch.flatten(x, 1)将会把x展平成一个形状为(64, 3*28*28)的张量。这里,第一个维度(样本数64)保持不变,而剩下的三个维度(3, 28, 28)被合并成了一个维度。

这种操作在处理图像数据时特别有用,尤其是在需要将图像数据传递给全连接层之前,因为全连接层通常期望输入是二维的(尽管在实践中,通常会先通过一个或多个卷积层来处理图像数据)。通过展平操作,你可以将多维的图像数据转换成二维的形式,以便进行后续处理。

x = torch.flatten(x, 2)

在PyTorch中,torch.flatten(input, start_dim=0, end_dim=-1) 函数用于将多维张量(tensor)展平(flatten)为一维张量,但你可以通过指定start_dimend_dim参数来控制从哪一维度开始展平,以及在哪一维度结束(不包括该维度)。这意味着你可以保留张量的某些维度不变,而将其他维度展平。

对于你的代码 x = torch.flatten(x, 2),这里:

  • x 是你想要展平的原始张量。
  • 2start_dim参数的值,而end_dim参数默认为-1,表示展平操作会一直进行到张量的最后一个维度。

因此,torch.flatten(x, 2) 的意思是从张量x的第3维(因为索引从0开始)开始,将之后的所有维度都展平成一个维度。如果x的形状是例如 (a, b, c, d, e),那么torch.flatten(x, 2)之后,x的形状将变为 (a, b, c*d*e)。这里,ab维度保持不变,而cde三个维度被合并成了一个新的维度。

这种操作在处理多维数据时非常有用,特别是当你需要将一部分数据的维度保持不变,而将其他部分数据“展平”以便于后续处理(如全连接层处理)时。

三、示例

import torch

A = torch.tensor([[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[13,14,15,16],[17,18,19,20],[21,22,23,24]]])
# print(A.size)
print(A.shape)

B = torch.flatten(A,1)
print(B.shape)
# print(B)

C = torch.flatten(A,0,1)
print(C.shape)
# print(C)

D = torch.flatten(A,2)
print(D.shape)
# print(D)

E = torch.flatten(A,0)
print(E.shape)
# print(E)

F = A.view(A.size(0), -1)
print(F.shape)
# print(F)

G = A.view(A.size(0), -1, 1)
print(G.shape)

输出:

torch.Size([2, 3, 4])
torch.Size([2, 12])
torch.Size([6, 4])
torch.Size([2, 3, 4])
torch.Size([24])
torch.Size([2, 12])
torch.Size([2, 12, 1])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python如何生成各种随机分布图

    python如何生成各种随机分布图

    这篇文章主要为大家详细介绍了python如何生成各种随机分布图,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-08-08
  • Python Flask-Login构建强大的用户认证系统实例探究

    Python Flask-Login构建强大的用户认证系统实例探究

    这篇文章主要为大家介绍了Python Flask-Login构建强大的用户认证系统示例探究,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2024-01-01
  • python数据分析近年比特币价格涨幅趋势分布

    python数据分析近年比特币价格涨幅趋势分布

    这篇文章主要为大家介绍了python分析近年来比特币价格涨幅趋势的数据分布,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步
    2021-11-11
  • Python使用GitPython操作Git版本库的方法

    Python使用GitPython操作Git版本库的方法

    这篇文章主要介绍了Python使用GitPython操作Git版本库的方法,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-02-02
  • Appium+python自动化之连接模拟器并启动淘宝APP(超详解)

    Appium+python自动化之连接模拟器并启动淘宝APP(超详解)

    这篇文章主要介绍了Appium+python自动化之 连接模拟器并启动淘宝APP(超详解)本文以淘宝app为例,通过实例代码给大家介绍的非常详细,需要的朋友可以参考下
    2019-06-06
  • python利用正则表达式提取字符串

    python利用正则表达式提取字符串

    相信大家在日常工作中经常会遇见在文本中提取特定位置字符串的需求,python的正则性很好,很适合做这类字符串的提取,所以这篇文章就给大家详细讲一下提取的技巧,并通过示例代码讲解,对大家理解很有帮助,有需要的朋友们下面来一起学习学习吧。
    2016-12-12
  • 使用Pytorch构建第一个神经网络模型 附案例实战

    使用Pytorch构建第一个神经网络模型 附案例实战

    这篇文章主要介绍了用Pytorch构建第一个神经网络模型(附案例实战),本文通过实例代码给大家讲解的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-03-03
  • 自学python求已知DNA模板的互补DNA序列

    自学python求已知DNA模板的互补DNA序列

    这篇文章主要为大家介绍了自学python求已知DNA模板的互补DNA序列的示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-06-06
  • Python中列表(List) 的三种遍历(序号和值)方法小结

    Python中列表(List) 的三种遍历(序号和值)方法小结

    这篇文章主要介绍了Python中列表(List) 的三种遍历(序号和值)方法小结,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-05-05
  • 对python中Json与object转化的方法详解

    对python中Json与object转化的方法详解

    今天小编就为大家分享一篇对python中Json与object转化的方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12

最新评论