Pytorch中torch.cat()函数的使用及说明

 更新时间:2023年01月03日 10:07:33   作者:cv_lhp  
这篇文章主要介绍了Pytorch中torch.cat()函数的使用及说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

torch.cat()函数解析

1. 函数说明

1.1 官网:torch.cat()

函数定义及参数说明如下图所示:

函数定义及参数说明

1.2 函数功能

函数将两个张量(tensor)按指定维度拼接在一起,注意:除拼接维数dim数值可不同外其余维数数值需相同,方能对齐,如下面例子所示。

torch.cat()函数不会新增维度,而torch.stack()函数会新增一个维度,相同的是两个都是对张量进行拼接

2. 代码举例

2.1 输入两个二维张量(dim=0):dim=0对行进行拼接

a = torch.randn(2,3)
b =  torch.randn(3,3)
c = torch.cat((a,b),dim=0)
a,b,c

输出结果如下:

(tensor([[-0.90, -0.37,  1.96],
         [-2.65, -0.60,  0.05]]),
 tensor([[ 1.30,  0.24,  0.27],
         [-1.99, -1.09,  1.67],
         [-1.62,  1.54, -0.14]]),
 tensor([[-0.90, -0.37,  1.96],
         [-2.65, -0.60,  0.05],
         [ 1.30,  0.24,  0.27],
         [-1.99, -1.09,  1.67],
         [-1.62,  1.54, -0.14]]))

2.2 输入两个二维张量(dim=1): dim=1对列进行拼接

a = torch.randn(2,3)
b =  torch.randn(2,4)
c = torch.cat((a,b),dim=1)
a,b,c

输出结果如下:

(tensor([[-0.55, -0.84, -1.60],
         [ 0.39, -0.96,  1.02]]),
 tensor([[-0.83, -0.09,  0.05,  0.17],
         [ 0.28, -0.74, -0.27, -0.85]]),
 tensor([[-0.55, -0.84, -1.60, -0.83, -0.09,  0.05,  0.17],
         [ 0.39, -0.96,  1.02,  0.28, -0.74, -0.27, -0.85]]))

2.3 输入两个三维张量:dim=0 对通道进行拼接

a = torch.randn(2,3,4)
b =  torch.randn(1,3,4)
c = torch.cat((a,b),dim=0)
a,b,c

输出结果如下:

(tensor([[[ 0.51, -0.72, -0.02,  0.76],
          [ 0.72,  1.01,  0.39, -0.13],
          [ 0.37, -0.63, -2.69,  0.74]],
 
         [[ 0.72, -0.31, -0.27,  0.10],
          [ 1.66, -0.06,  1.91, -0.66],
          [ 0.34, -0.23, -0.18, -1.22]]]),
 tensor([[[ 0.94,  0.77, -0.41, -1.20],
          [-0.23, -1.03, -0.25,  1.67],
          [-1.00, -0.68, -0.35, -0.50]]]),
 tensor([[[ 0.51, -0.72, -0.02,  0.76],
          [ 0.72,  1.01,  0.39, -0.13],
          [ 0.37, -0.63, -2.69,  0.74]],
 
         [[ 0.72, -0.31, -0.27,  0.10],
          [ 1.66, -0.06,  1.91, -0.66],
          [ 0.34, -0.23, -0.18, -1.22]],
 
         [[ 0.94,  0.77, -0.41, -1.20],
          [-0.23, -1.03, -0.25,  1.67],
          [-1.00, -0.68, -0.35, -0.50]]]))

2.4 输入两个三维张量:dim=1对行进行拼接

a = torch.randn(2,3,4)
b =  torch.randn(2,4,4)
c = torch.cat((a,b),dim=1)
a,b,c

输出结果如下:

(tensor([[[-0.86,  0.00, -1.26,  1.20],
          [-0.46, -1.08, -0.82,  2.03],
          [-0.89,  0.43,  1.92,  0.49]],
 
         [[ 0.24, -0.02,  0.32,  0.97],
          [ 0.33, -1.34,  0.76, -1.55],
          [ 0.38,  1.45,  0.27, -0.64]]]),
 tensor([[[ 0.82,  0.85, -0.30, -0.58],
          [-0.09,  0.40,  0.02,  0.75],
          [-0.70,  0.67, -0.88, -0.50],
          [-0.62, -1.65, -1.10, -1.39]],
 
         [[-0.85, -1.61, -0.35, -0.56],
          [ 0.00,  1.40,  0.41,  0.39],
          [-0.01,  0.04,  0.80,  0.41],
          [-1.21, -0.64,  1.14,  1.64]]]),
 tensor([[[-0.86,  0.00, -1.26,  1.20],
          [-0.46, -1.08, -0.82,  2.03],
          [-0.89,  0.43,  1.92,  0.49],
          [ 0.82,  0.85, -0.30, -0.58],
          [-0.09,  0.40,  0.02,  0.75],
          [-0.70,  0.67, -0.88, -0.50],
          [-0.62, -1.65, -1.10, -1.39]],
 
         [[ 0.24, -0.02,  0.32,  0.97],
          [ 0.33, -1.34,  0.76, -1.55],
          [ 0.38,  1.45,  0.27, -0.64],
          [-0.85, -1.61, -0.35, -0.56],
          [ 0.00,  1.40,  0.41,  0.39],
          [-0.01,  0.04,  0.80,  0.41],
          [-1.21, -0.64,  1.14,  1.64]]]))

2.5 输入两个三维张量:dim=2对列进行拼接

a = torch.randn(2,3,4)
b =  torch.randn(2,3,5)
c = torch.cat((a,b),dim=2)
a,b,c

输出结果如下:

(tensor([[[ 0.13, -0.02,  0.13, -0.25],
          [ 1.42, -0.22, -0.87,  0.27],
          [-0.07,  1.04, -0.06,  0.91]],
 
         [[ 0.88, -1.46,  0.04,  0.35],
          [ 1.36,  0.64,  0.75,  0.39],
          [ 0.36,  1.13,  0.83,  0.56]]]),
 tensor([[[-0.47, -2.30, -0.49, -1.02,  1.74],
          [ 0.71,  0.89,  0.80, -0.05, -1.35],
          [-0.40,  0.26, -0.78, -1.50, -0.92]],
 
         [[-0.77, -0.01,  1.23,  0.70, -0.66],
          [ 0.28, -0.18, -0.91,  2.23,  1.14],
          [-1.93, -0.17,  0.15,  0.40,  0.32]]]),
 tensor([[[ 0.13, -0.02,  0.13, -0.25, -0.47, -2.30, -0.49, -1.02,  1.74],
          [ 1.42, -0.22, -0.87,  0.27,  0.71,  0.89,  0.80, -0.05, -1.35],
          [-0.07,  1.04, -0.06,  0.91, -0.40,  0.26, -0.78, -1.50, -0.92]],
 
         [[ 0.88, -1.46,  0.04,  0.35, -0.77, -0.01,  1.23,  0.70, -0.66],
          [ 1.36,  0.64,  0.75,  0.39,  0.28, -0.18, -0.91,  2.23,  1.14],
          [ 0.36,  1.13,  0.83,  0.56, -1.93, -0.17,  0.15,  0.40,  0.32]]]))

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 解决pycharm界面不能显示中文的问题

    解决pycharm界面不能显示中文的问题

    今天小编就为大家分享一篇解决pycharm界面不能显示中文的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • django 实现后台从富文本提取纯文本

    django 实现后台从富文本提取纯文本

    这篇文章主要介绍了django 实现后台从富文本提取纯文本,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-07-07
  • Jupyter notebook 输出部分显示不全的解决方案

    Jupyter notebook 输出部分显示不全的解决方案

    这篇文章主要介绍了Jupyter notebook 输出部分显示不全的解决方案,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-04-04
  • Python访问MySQL封装的常用类实例

    Python访问MySQL封装的常用类实例

    这篇文章主要介绍了Python访问MySQL封装的常用类,实例详述了针对MySQL使用query执行select及使用update进行insert、delete等操作的方法,需要的朋友可以参考下
    2014-11-11
  • Python线程之认识线程安全 

    Python线程之认识线程安全 

    这篇文章主要介绍了Python线程之认识线程安全,线程安全,名字就非常直接,在多线程情况下是安全的,多线程操作上的安全,下面学习线程安全的文章详细内容,需要的小伙伴可以参考一下
    2022-02-02
  • PyTorch 1.0 正式版已经发布了

    PyTorch 1.0 正式版已经发布了

    今天小编就为大家分享一篇关于PyTorch 1.0 正式版已经发布了!小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2018-12-12
  • pytest进阶教程之fixture函数详解

    pytest进阶教程之fixture函数详解

    这篇文章主要给大家介绍了关于pytest进阶教程之fixture函数的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-03-03
  • python获取服务器响应cookie的实例

    python获取服务器响应cookie的实例

    今天小编就为大家分享一篇python获取服务器响应cookie的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • python输入一个水仙花数(三位数) 输出百位十位个位实例

    python输入一个水仙花数(三位数) 输出百位十位个位实例

    这篇文章主要介绍了python输入一个水仙花数(三位数) 输出百位十位个位实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • python的paramiko模块基本用法详解

    python的paramiko模块基本用法详解

    paramiko 是一个用于在Python中执行远程操作的模块,支持SSH协议,它可以用于连接到远程服务器,执行命令、上传和下载文件,以及在远程服务器上执行各种操作,这篇文章主要介绍了python的paramiko模块基本用法,需要的朋友可以参考下
    2023-08-08

最新评论