支持PyTorch的einops张量操作神器用法示例详解

 更新时间:2021年11月01日 17:26:06   作者:木盏  
这篇文章主要为大家介绍了支持PyTorch的einops张量操作神器用法示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步早日升职加薪

今天做visual transformer研究的时候,发现了einops这么个神兵利器,决定大肆安利一波。

先看链接:https://github.com/arogozhnikov/einops

安装:

pip install einops

基础用法

einops的强项是把张量的维度操作具象化,让开发者“想出即写出”。举个例子:

from einops import rearrange
 
# rearrange elements according to the pattern
output_tensor = rearrange(input_tensor, 'h w c -> c h w')

用'h w c -> c h w'就完成了维度调换,这个功能与pytorch中的permute相似。但是,einops的rearrange玩法可以更高级:

from einops import rearrange
import torch
 
a = torch.randn(3, 9, 9)  # [3, 9, 9]
output = rearrange(a, 'c (r p) w -> c r p w', p=3)
print(output.shape)   # [3, 3, 3, 9]

这就是高级用法了,把中间维度看作r×p,然后给出p的数值,这样系统会自动把中间那个维度拆解成3×3。这样就完成了[3, 9, 9] -> [3, 3, 3, 9]的维度转换。

这个功能就不是pytorch的内置功能可比的。

除此之外,还有reduce和repeat,也是很好用。

from einops import repeat
import torch
 
a = torch.randn(9, 9)  # [9, 9]
output_tensor = repeat(a, 'h w -> c h w', c=3)  # [3, 9, 9]

指定c,就可以指定复制的层数了。

再看reduce:

from einops import reduce
import torch
 
a = torch.randn(9, 9)  # [9, 9]
output_tensor = reduce(a, 'b c (h h2) (w w2) -> b h w c', 'mean', h2=2, w2=2)

这里的'mean'指定池化方式。 相信你看得懂,不懂可留言提问~

高级用法 

einops也可以嵌套在pytorch的layer里,请看:

# example given for pytorch, but code in other frameworks is almost identical  
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
from einops.layers.torch import Rearrange
 
model = Sequential(
    Conv2d(3, 6, kernel_size=5),
    MaxPool2d(kernel_size=2),
    Conv2d(6, 16, kernel_size=5),
    MaxPool2d(kernel_size=2),
    # flattening
    Rearrange('b c h w -> b (c h w)'),  
    Linear(16*5*5, 120), 
    ReLU(),
    Linear(120, 10), 
)

这里的Rearrange是nn.module的子类,直接可以当作网络层放到模型里~

一个字,绝。

以上就是支持PyTorch的einops张量操作神器用法示例详解的详细内容,更多关于einops张量操作用法的资料请关注脚本之家其它相关文章!

相关文章

  • Python实现LR1文法的完整实例代码

    Python实现LR1文法的完整实例代码

    这篇文章主要给大家介绍了关于Python实现LR1文法的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-10-10
  • python实现封装得到virustotal扫描结果

    python实现封装得到virustotal扫描结果

    这篇文章主要介绍了python实现封装得到virustotal扫描结果的方法,是比较实用的技巧,可将扫描结果写入数据库,需要的朋友可以参考下
    2014-10-10
  • django rest framework 过滤时间操作

    django rest framework 过滤时间操作

    这篇文章主要介绍了django rest framework 过滤时间操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-07-07
  • Python实现智慧校园自动评教全新版

    Python实现智慧校园自动评教全新版

    上一次的智慧校园自动评教是用的selenium库去模拟人去对浏览器进行点击操作,虽然比手动评教要快,但是效率还是不高.从而想去尝试重新写一份不用selenium的评教方案,功夫不负有心人,最终成功了,需要的朋友可以参考下
    2021-06-06
  • Jupyter Notebook读取csv文件出现的问题及解决

    Jupyter Notebook读取csv文件出现的问题及解决

    这篇文章主要介绍了Jupyter Notebook读取csv文件出现的问题及解决,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-01-01
  • python celery分布式任务队列的使用详解

    python celery分布式任务队列的使用详解

    这篇文章主要介绍了python celery分布式任务队列的使用详解,Celery 是一个 基于python开发的分布式异步消息任务队列,通过它可以轻松的实现任务的异步处理, 如果你的业务场景中需要用到异步任务,就可以考虑使用celery,需要的朋友可以参考下
    2019-07-07
  • C#返回当前系统所有可用驱动器符号的方法

    C#返回当前系统所有可用驱动器符号的方法

    这篇文章主要介绍了C#返回当前系统所有可用驱动器符号的方法,涉及C#操作系统硬件驱动的相关技巧,需要的朋友可以参考下
    2015-04-04
  • python可视化实现KNN算法

    python可视化实现KNN算法

    这篇文章主要为大家详细介绍了python可视化实现KNN算法,通过绘图工具Matplotlib包可视化实现机器学习中的KNN算法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-10-10
  • 使用Numpy读取CSV文件,并进行行列删除的操作方法

    使用Numpy读取CSV文件,并进行行列删除的操作方法

    今天小编就为大家分享一篇使用Numpy读取CSV文件,并进行行列删除的操作方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • python通过ElementTree操作XML获取结点读取属性美化XML

    python通过ElementTree操作XML获取结点读取属性美化XML

    本文讲解如何通过ElementTree解析XML,获取儿子结点、插入儿子结点、操作属性、美化XML
    2013-12-12

最新评论