关于torch.scatter与torch_scatter库的使用整理

 更新时间:2023年09月11日 14:36:18   作者:回炉重造P  
这篇文章主要介绍了关于torch.scatter与torch_scatter库的使用整理,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

最近在做图结构相关的算法,scatter能把邻接矩阵里的信息修改,或者把邻居分组算个sum或者reduce,挺方便的,简单整理一下。

torch.scatter 与 tensor._scatter

Pytorch自带的函数,用来将作为 src 的tensor根据 index 的描述填充到 input 中,

形式如下:

ouput = torch.scatter(input, dim, index, src)
# 或者是
input.scatter_(dim, index, src)

两个方法的功能是相同的,而带下划线的 _scatter 方法是将原tensor input 直接修改了,不带的则会返回一个新的tensor output input 不变。

其中 dim 决定 index 对应值是沿着哪个维度进行修改。而 src 为数据来源,当其为tensor张量时,shape要和index相同,这样index中每个元素都能对应 src 中对应位置的信息。

理解 scatter 方法主要是要理解 index 实现的 src input 之间的位置对应关系,举个例子:

dim = 0
index = torch.tensor(
	[[0, 2, 2], 
	[2, 1, 0]]
)

dim 为0时,遵循的映射原则为: input[index[i][j]][j] = src[i][j] .

也就是说,将位置 (i, j) 中 dim 对应的位置改为 index[i][j] 的值。

如位置(1,0),index[1][0]为2,则映射后的位置为(2,0),意味着 input 中(2,0)的位置被更改为 src 中(1,0)位置的值。

我个人形象理解是这些值会沿着dim方向滑动,上面例子中src[1][0]位置的值滑到2,成为input中的新值,这样理解起来更形象一点。

基本理解了上面这个例子,多维情况和不同dim的情况都可以类推了。

需要注意:src和input的dtype需要相同,不然会报

Expected self.dtype to be equal to src.dtype

不一样就先转换再使用。

t = torch.arange(6).view(2, 3)
t = t.to(torch.float32)
print(t)
output = torch.scatter(torch.zeros((3, 3)), 0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t)
print(torch.zeros((3, 3)).scatter_(0, torch.tensor([[0, 2, 2], [2, 1, 0]]), t))

输出:

tensor([[0., 1., 2.],
        [3., 4., 5.]])
tensor([[0., 0., 5.],
        [0., 4., 0.],
        [3., 1., 2.]])

torch_scatter库

这个第三方库对矩阵的分组处理这个概念做了更进一步的封装,通过index来指定分组信息,将元素分组后进行对应处理,

最基础的scatter方法形式如下:

torch_scatter.scatter(src, index, dim, out, dim_size, reduce)
  • src : 数据源
  • index :分组序列
  • dim :分组遵循的维度
  • out :输出的tensor,可以不指定直接让函数输出
  • dim_size :out不指定的时候,将输出shape变为该值大小;dim_size也不指定,就根据计算结果来
  • reduce :分组的操作,包括sum,mul,mean,min和max操作

这个方法理解关键在 index 的分组方法,

举个例子:

dim = 1
index = torch.tensor([[0, 1, 1]])

torch_scatter.scatter index 的顺序是没有特定规定的,相同数字对应的元素即为一组。

比如例子中,维度1上的第0个元素为一组,第1和2元素为另一组。

这样,按照分组进行reduce定义的计算即可获得输出。如:

t = torch.arange(12).view(4, 3)
print(t)
t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1]]), dim=1, reduce='sum')
print(t_s)

输出:

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[ 0,  3],
        [ 3,  9],
        [ 6, 15]])

可以看出,每行的后两个元素求了和,与index定义相同。

要注意的是,index的 shape[0] 为1时,会自动对dim对应的维度上每一层进行相同的分组处理,如上例所示,index大小为(1, 3),即对src的三行数据都进行了分组处理。

而另一种分组方式,如需要每行分组不同,则需要index的shape和src的shape相同,如下例:

t = torch.arange(12).view(4, 3)
print(t)
t_s = torch_scatter.scatter(t, torch.tensor([[0, 1, 1], [1, 1, 0], [0, 1, 1], [1, 1, 0]]), dim=1, reduce='sum')
print(t_s)

输出:

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
tensor([[ 0,  3],
        [ 5,  7],
        [ 6, 15]])

shape不相同时,则会报错提示:

RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 0 .

同时,该库还给出了另外两种方法,分别为 torch_scatter.segment_coo torch_scatter.segment_csr .

torch_scatter.segment_coo

torch_scatter.segment_coo scatter 的功能差不多,但它只支持index的shape[0]为1的状态,即每一行都为相同的分组方式。

同时,index中数值为顺序排列,以提高计算速度。

torch_scatter.segment_csr

torch_scatter.segment_csr 的index格式不太相同,是一种区间格式,如[0, 2, 5],表示0,1为一组,2,3,4为一组,即取数值间的左闭右开区间。

这个方法是计算速度最快的。

官方文档地址

torch_scatter库doc

https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html

torch.scatter文档

https://pytorch-cn.readthedocs.io/zh/latest/package_references/Tensor/#scatter_input-dim-index-src-tensor

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python编写简易聊天室实现局域网内聊天功能

    python编写简易聊天室实现局域网内聊天功能

    这篇文章主要为大家详细介绍了python编写简易聊天室实现局域网内聊天功能,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-07-07
  • pycharm 实现调试窗口恢复

    pycharm 实现调试窗口恢复

    这篇文章主要介绍了pycharm 实现调试窗口恢复的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-02-02
  • 浅谈python日志的配置文件路径问题

    浅谈python日志的配置文件路径问题

    下面小编就为大家分享一篇浅谈python日志的配置文件路径问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • 分析Python list操作为什么会错误

    分析Python list操作为什么会错误

    这篇文章主要介绍了分析Python list操作为什么会错误,python搞数据分析,在很多方面python有着比Matlab更大的优势,下面来看看文章具体介绍的相关内容吧,需要的朋友可以参考一下
    2021-11-11
  • Python中利用LSTM模型进行时间序列预测分析的实现

    Python中利用LSTM模型进行时间序列预测分析的实现

    这篇文章主要介绍了Python中利用LSTM模型进行时间序列预测分析的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • python通过post提交数据的方法

    python通过post提交数据的方法

    这篇文章主要介绍了python通过post提交数据的方法,涉及Python使用post方式传递数据的相关技巧,需要的朋友可以参考下
    2015-05-05
  • Pandas分组聚合之groupby()、agg()方法的使用教程

    Pandas分组聚合之groupby()、agg()方法的使用教程

    今天看到pandas的聚合函数agg,比较陌生,平时的工作中处理数据的时候使用的也比较少,为了加深印象,总结一下使用的方法,下面这篇文章主要给大家介绍了关于Pandas分组聚合之groupby()、agg()方法的使用教程,需要的朋友可以参考下
    2023-01-01
  • 如何在python中使用selenium的示例

    如何在python中使用selenium的示例

    这篇文章主要介绍了如何在python中使用selenium的示例,selenium提供了一个通用的接口,可模拟用户来操作浏览器,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-12-12
  • python opencv 读取本地视频文件 修改ffmpeg的方法

    python opencv 读取本地视频文件 修改ffmpeg的方法

    今天小编就为大家分享一篇python opencv 读取本地视频文件 修改ffmpeg的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • Python合并pdf文件的工具

    Python合并pdf文件的工具

    PDF文件合并工具是非常好用可以把多个pdf文件合并成一个,本文以5个pdf文件为例给大家分享具体操作方法,通过实例代码给大家介绍的非常详细,需要的朋友参考下吧
    2021-07-07

最新评论