pytorch中unsqueeze用法小结

 更新时间:2024年04月17日 11:08:05   作者:ym62033  
unsqueeze()的作用是用来增加给定tensor的维度的,本文主要介绍了pytorch中unsqueeze用法小结,具有一定的参考价值,感兴趣的可以了解一下

在指定的位置插入一个维度,有两个参数,input是输入的tensor,dim是要插到的维度

需要注意的是dim的范围是[-input.dim()-1, input.dim()+1),是一个左闭右开的区间,当dim为负值时,会自动转换为dim = dim+input.dim()+1,类似于使用负数对python列表进行切片。

import torch

a = torch.randn(2,5)
print(a)

print("")
b = a.unsqueeze(0)
print(b.shape)

print("")
c = a.unsqueeze(a.dim())
print(c.shape)


输出:
tensor([[-0.4734,  0.4115, -0.9415, -1.1280, -0.1065],
        [ 0.1613,  1.2594,  1.1261,  1.3881,  0.1112]])

torch.Size([1, 2, 5])

torch.Size([2, 5, 1])

以上是二维数据情况:

首先生成了一个二维矩阵,其大小为[2,5]

然后,在0维度上插入一个维度,可以看到现在新矩阵a的形状变为[1,2,5],第0维度的大小默认是1

最后,在最后一个维度上插入一个维度,形状变为[2, 5, 1]

a=torch.rand(2,3,2)

print("")
print("torch.unsqueeze(a,3) size: {}".format(torch.unsqueeze(a,3).size()))

print("")
print("torch.unsqueeze(a,2) size: {}".format(torch.unsqueeze(a,2).size()))

print("")
print("torch.unsqueeze(a,1) size: {}".format(torch.unsqueeze(a,1).size()))

print("")
print("torch.unsqueeze(a,0) size: {}".format(torch.unsqueeze(a,0).size()))
 
print("")
print("torch.unsqueeze(a,-1) size: {}".format(torch.unsqueeze(a,-1).size()))

print("")
print("torch.unsqueeze(a,-2) size: {}".format(torch.unsqueeze(a,-2).size()))

print("")
print("torch.unsqueeze(a,-3) size: {}".format(torch.unsqueeze(a,-3).size()))

print("")
print("torch.unsqueeze(a,-4) size: {}".format(torch.unsqueeze(a,-4).size()))

输出:
torch.unsqueeze(a,3) size: torch.Size([2, 3, 2, 1])

torch.unsqueeze(a,2) size: torch.Size([2, 3, 1, 2])

torch.unsqueeze(a,1) size: torch.Size([2, 1, 3, 2])

torch.unsqueeze(a,0) size: torch.Size([1, 2, 3, 2])

torch.unsqueeze(a,-1) size: torch.Size([2, 3, 2, 1])

torch.unsqueeze(a,-2) size: torch.Size([2, 3, 1, 2])

torch.unsqueeze(a,-3) size: torch.Size([2, 1, 3, 2])

torch.unsqueeze(a,-4) size: torch.Size([1, 2, 3, 2])

对于三维数据input.dim() = 3,因此dim的范围是[-4, 4)

torch.squeeze() 和 torch.unsqueeze()区别

第一块:

squeeze(),主要是对数据的维度进行压缩,去掉元素数为1的那个维度,使用方式:a.squeeze(N) or torch.squeeze(a,N) ,去掉a的第N维度,以此来实现数据a的维度压缩;

unsqueeze()与squeeze()函数功能相反,其功能是对数据维度进行扩充,使用方式:a.unsqueeze(N) or torch.unsqueeze(a,N),在数据a的第N维度上增加一个维数为1的维度,以此实现对数据的扩充,方便后续模型训练喂入模型的数据的维度和模型接收数据的维度是匹配的。

第二块:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model.to(device) # 选择第0个cuda

model.to(device)

以上两行代码放在读取数据之前。

mytensor = my_tensor.to(device) #将所有最开始读取数据时的tensor变量copy一份到device所指定的GPU上,之后运算都在指定的GPU上进行。这些tensor多是最开始读取数据时的变量,后面其衍生出的新变量也会在已指定的GPU上运行计算。

第三块:

Tensor & Numpy 都是矩阵,区别在与Tensor可以在GPU上运行,Numpy只能在CPU上运行。(天呐,我现在才知道!)Tensor与Numpy互相转化很方便,类型也比较兼容,Tensor可以直接通过print显示数据类型,而Numpy不可以。

第四块:

x.aadd(y) 实现x与y Tensor的相加,不改变x,返回一个新的Tensor

x.add_(y)  实现x与y Tensor的相加,会修改x的维数

到此这篇关于pytorch中unsqueeze用法小结的文章就介绍到这了,更多相关pytorch unsqueeze内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python 基于线程的并行 threading模块的用法

    Python 基于线程的并行 threading模块的用法

    threading模块是Python的高级线程接口,提供线程对象和同步工具,本文主要介绍了Python 基于线程的并行 threading模块的用法,感兴趣的可以了解一下
    2025-06-06
  • pyttsx3实现中文文字转语音的方法

    pyttsx3实现中文文字转语音的方法

    今天小编就为大家分享一篇pyttsx3实现中文文字转语音的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • Python报错TypeError: object of type ‘generator‘ has no len ()的解决方法

    Python报错TypeError: object of type ‘gener

    在Python开发的复杂世界中,报错信息就像神秘的谜题,困扰着开发者和环境配置者,其中,TypeError: object of type ‘generator’ has no len()这个报错,常常在不经意间打乱我们的开发节奏,本文让我们一起深入探究这个报错问题,为Python开发之路扫除障碍
    2024-10-10
  • 使用Django清空数据库并重新生成

    使用Django清空数据库并重新生成

    这篇文章主要介绍了使用Django清空数据库并重新生成,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • Python调用OpenAI Agents SDK打造一个多智能体系统

    Python调用OpenAI Agents SDK打造一个多智能体系统

    还在手搓Agent通信逻辑吗,OpenAI官方SDK让你用纯Python代码构建生产级多智能体系统,本文从零到一,基于Python调用OpenAI Agents SDK打造你的第一个多智能体系统,需要的朋友可以参考下
    2026-04-04
  • python操作注册表的方法实现

    python操作注册表的方法实现

    Python提供了winreg模块,可以用于操作Windows注册表,本文就来介绍一下python操作注册表的方法实现,主要包括打开注册表、读取注册表值、写入注册表值和关闭注册表,具有一定的参考价值,感兴趣的可以了解一下
    2023-08-08
  • python获取标准北京时间的方法

    python获取标准北京时间的方法

    这篇文章主要介绍了python获取标准北京时间的方法,实例分析了Python通过www.beijing-time.org的官网获取标准北京时间的技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-03-03
  • Python基于pandas爬取网页表格数据

    Python基于pandas爬取网页表格数据

    这篇文章主要介绍了Python基于pandas获取网页表格数据,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • 基于python实现微信收红包自动化测试脚本(测试用例)

    基于python实现微信收红包自动化测试脚本(测试用例)

    这篇文章主要介绍了基于python实现微信收红包自动化测试脚本,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧
    2021-07-07
  • Python自定义logger模块的实例代码

    Python自定义logger模块的实例代码

    Python标准库中的logging模块提供了日志记录的功能,自定义 Logger 可以根据项目的需求定制化日志记录,满足特定的日志记录格式、输出目标和日志级别等要求,本文给大家介绍了Python自定义logger模块的实例代码,需要的朋友可以参考下
    2024-02-02

最新评论