聊聊Pytorch torch.cat与torch.stack的区别

 更新时间:2021年05月20日 11:05:21   作者:Winner3  
这篇文章主要介绍了Pytorch torch.cat与torch.stack的区别说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

torch.cat()函数可以将多个张量拼接成一个张量。torch.cat()有两个参数,第一个是要拼接的张量的列表或是元组;第二个参数是拼接的维度。

torch.cat()的示例如下图1所示

图1 torch.cat()

torch.stack()函数同样有张量列表和维度两个参数。stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。

torch.stack()的示例如下图2所示:

图2 torch.stack()

补充:torch.stack()的官方解释,详解以及例子

可以直接看最下面的【3.例子】,再回头看前面的解释

在pytorch中,常见的拼接函数主要是两个,分别是:

1、stack()

2、cat()

实际使用中,这两个函数互相辅助:关于cat()参考torch.cat(),但是本文主要说stack()。

函数的意义:使用stack可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。

形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面(矩阵)按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。

该函数常出现在自然语言处理(NLP)和图像卷积神经网络(CV)中。

1. stack()

官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

outputs = torch.stack(inputs, dim=?) → Tensor

参数

inputs : 待连接的张量序列。

注:python的序列数据只有list和tuple。

dim : 新的维度, 必须在0到len(outputs)之间。

注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。

2. 重点

函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等

----举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape

dim是选择生成的维度,必须满足0<=dim<len(outputs);len(outputs)是输出后的tensor的维度大小

不懂的看例子,再回过头看就懂了。

3. 例子

1.准备2个tensor数据,每个的shape都是[3,3]

# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
          [40, 50, 60],
          [70, 80, 90]])

2.测试stack函数

print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)
print(torch.stack((T1,T2),dim=3).shape)
# outputs:
torch.Size([2, 3, 3])
torch.Size([3, 2, 3])
torch.Size([3, 3, 2])
'选择的dim>len(outputs),所以报错'
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

可以复制代码运行试试:拼接后的tensor形状,会根据不同的dim发生变化。

dim shape
0 [2, 3, 3]
1 [3, 2, 3]
2 [3, 3, 2]
3 溢出报错

4. 总结

1、函数作用:

函数stack()对序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。

2、存在意义:

在自然语言处理和卷及神经网络中, 通常为了保留–[序列(先后)信息] 和 [张量的矩阵信息] 才会使用stack。

函数存在意义?》》》

手写过RNN的同学,知道在循环神经网络中输出数据是:一个list,该列表插入了seq_len个形状是[batch_size, output_size]的tensor,不利于计算,需要使用stack进行拼接,保留–[1.seq_len这个时间步]和–[2.张量属性[batch_size, output_size]]。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python闭包、深浅拷贝、垃圾回收、with语句知识点汇总

    python闭包、深浅拷贝、垃圾回收、with语句知识点汇总

    在本篇文章里小编给大家整理了关于python闭包、深浅拷贝、垃圾回收、with语句知识点汇总,有兴趣的朋友们学习下。
    2020-03-03
  • Python logging模块学习笔记

    Python logging模块学习笔记

    这篇文章主要介绍了Python logging模块,logging模块是在2.3新引进的功能,用来处理程序运行中的日志管理,本文详细讲解了该模块的一些常用的类和模块级函数,需要的朋友可以参考下
    2014-05-05
  • python制作爬虫爬取京东商品评论教程

    python制作爬虫爬取京东商品评论教程

    本文是继前2篇Python爬虫系列文章的后续篇,给大家介绍的是如何使用Python爬取京东商品评论信息的方法,并根据数据绘制成各种统计图表,非常的细致,有需要的小伙伴可以参考下
    2016-12-12
  • python机器学习之贝叶斯分类

    python机器学习之贝叶斯分类

    这篇文章主要为大家详细介绍了python机器学习之贝叶斯分类的相关资料,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03
  • python palywright库基本使用

    python palywright库基本使用

    这篇文章主要介绍了python palywright库的基本使用,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2021-01-01
  • 利用Python写个摸鱼监控进程

    利用Python写个摸鱼监控进程

    继打游戏、看视频等摸鱼行为被监控后,现在打工人离职的倾向也会被监控。今天就带大家领略一下怎么写几行Python代码,就能监控电脑,感兴趣的可以学习一下
    2022-02-02
  • TensorFlow dataset.shuffle、batch、repeat的使用详解

    TensorFlow dataset.shuffle、batch、repeat的使用详解

    今天小编就为大家分享一篇TensorFlow dataset.shuffle、batch、repeat的使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • Django日志和调试工具栏实现高效的应用程序调试和性能优化

    Django日志和调试工具栏实现高效的应用程序调试和性能优化

    这篇文章主要介绍了Django日志和调试工具栏实现高效的应用程序调试和性能优化,Django日志和调试工具栏为开发者提供了快速定位应用程序问题的工具,可提高调试和性能优化效率,提高应用程序的可靠性和可维护性
    2023-05-05
  • Python3 微信支付(小程序支付)V3接口的实现

    Python3 微信支付(小程序支付)V3接口的实现

    本文主要介绍了Python3 微信支付(小程序支付)V3接口的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-01-01
  • 详解Python发送邮件实例

    详解Python发送邮件实例

    这篇文章主要介绍了Python发送邮件实例,Python发送邮件需要smtplib和email两个模块,感兴趣的小伙伴们可以参考一下
    2016-01-01

最新评论