PyTorch中关于tensor.repeat()的使用
关于tensor.repeat()的使用
考虑到很多人在学习这个函数,我想在这里提 一个建议:
强烈推荐 使用 einops 模块中的 repeat() 函数 替代 tensor.repeat()!
它可以摆脱 tensor.repeat() 参数的神秘主义。
einops 模块文档地址:https://nbviewer.jupyter.org/github/arogozhnikov/einops/blob/master/docs/1-einops-basics.ipynb
学习 tensor.repeat() 这个函数的功能的时候,最好还是要观察所得到的 结果的维度。
不多说,看代码:
>>> import torch >>> >>> # 定义一个 33x55 张量 >>> a = torch.randn(33, 55) >>> a.size() torch.Size([33, 55]) >>> >>> # 下面开始尝试 repeat 函数在不同参数情况下的效果 >>> a.repeat(1,1).size() # 原始值:torch.Size([33, 55]) torch.Size([33, 55]) >>> >>> a.repeat(2,1).size() # 原始值:torch.Size([33, 55]) torch.Size([66, 55]) >>> >>> a.repeat(1,2).size() # 原始值:torch.Size([33, 55]) torch.Size([33, 110]) >>> >>> a.repeat(1,1,1).size() # 原始值:torch.Size([33, 55]) torch.Size([1, 33, 55]) >>> >>> a.repeat(2,1,1).size() # 原始值:torch.Size([33, 55]) torch.Size([2, 33, 55]) >>> >>> a.repeat(1,2,1).size() # 原始值:torch.Size([33, 55]) torch.Size([1, 66, 55]) >>> >>> a.repeat(1,1,2).size() # 原始值:torch.Size([33, 55]) torch.Size([1, 33, 110]) >>> >>> a.repeat(1,1,1,1).size() # 原始值:torch.Size([33, 55]) torch.Size([1, 1, 33, 55]) >>> >>> # ------------------ 割割 ------------------ >>> # repeat()的参数的个数,不能少于被操作的张量的维度的个数, >>> # 下面是一些错误示例 >>> a.repeat(2).size() # 1D < 2D, error Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor >>> >>> # 定义一个3维的张量,然后展示前面提到的那个错误 >>> b = torch.randn(5,6,7) >>> b.size() # 3D torch.Size([5, 6, 7]) >>> >>> b.repeat(2).size() # 1D < 3D, error Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor >>> >>> b.repeat(2,1).size() # 2D < 3D, error Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor >>> >>> b.repeat(2,1,1).size() # 3D = 3D, okay torch.Size([10, 6, 7]) >>>
Tensor.repeat()的简单用法
相当于手动实现广播机制,即沿着给定的维度对tensor进行重复:
比如说对下面x的第1个通道复制三次,其余通道保持不变:
import torch x = torch.randn(1, 3, 224, 224) y = x.repeat(3, 1, 1, 1) print(x.shape) print(y.shape)
结果为:
torch.Size([1, 3, 224, 224])
torch.Size([3, 3, 224, 224])
这个在复制batch的时候用的比较多,上面的情况就相当于batch为1的3×224×224特征图复制成了batch为3
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
详解Python logging调用Logger.info方法的处理过程
这篇文章主要介绍了详解Python logging调用Logger.info方法的处理过程,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧2019-02-02浅谈sklearn中predict与predict_proba区别
这篇文章主要介绍了浅谈sklearn中predict与predict_proba区别,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2020-06-06
最新评论