pytorch之torch_scatter.scatter_max()用法

 更新时间:2023年09月11日 11:45:10   作者:A2333fun  
这篇文章主要介绍了pytorch之torch_scatter.scatter_max()用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch_scatter.scatter_max()

torch_scatter.scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None)

  • 根据index将src分组,求每一组中的最大值输出到out
  • dim是维度

from torch_scatter import scatter_max
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
'''src根据index进行分组'''
out, argmax = scatter_max(src, index, out=out)
print(out)
print(argmax)

输出

tensor([[0., 0., 4., 3., 2., 0.],
        [2., 4., 3., 0., 0., 0.]])
tensor([[-1, -1,  3,  4,  0,  1],
        [ 1,  4,  3, -1, -1, -1]])

解释

torch_scatter.scatter()使用

1. 参数

具体来讲,scatter函数的作用就是将index中相同索引对应位置的src元素进行某种方式的操作,例如 sum mean 等,然后将这些操作结果按照索引顺序进行拼接。

下面我用具体的例子来进行讲解。

2. 示例

2.1 简单示例

首先初始化src和index:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
index = torch.tensor([0, 0, 1], dtype=torch.int64)

接着使用scatter函数:

out = scatter(src, index, dim=0, reduce='mean')

我们观察 index=[0, 0, 1] ,第0个位置和第1个位置都为0,第2个位置为1。也就是说,我们需要将src中第0个元素和第1个元素求平均变成一个元素,然后第2个元素求mean也就是本身为一个元素。如果 index=[1, 0, 0] ,则意味着我们需要将src中第1个元素和第2个元素求平均变成一个元素,而第0个元素保持不变。

那么src中第几个元素到底是如何定义的呢?这就需要用到 dim 参数了。

dim=0 意味着我们需要对src的维度0进行操作:

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

即src中第0个元素为 [1, 2, 3] ,第1个元素为 [4, 5, 6] ,第2个元素为 [7, 8, 9]

而如果 dim=1 ,则第0个元素为 [1, 4, 7] ,第1个元素为 [2, 5, 8] ,第2个元素为 [3, 6, 9]

因此,如果有以下代码:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
index = torch.tensor([0, 0, 1], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')

那么我们就应该将src中的第0个元素为 [1, 2, 3] 和第1个元素为 [4, 5, 6] 求平均为 [2.5, 3.5, 4.5] ,然后第2个元素 [7, 8, 9] 保持不变,即:

tensor([[2.5000, 3.5000, 4.5000],
        [7.0000, 8.0000, 9.0000]])

2.2 顺序问题

上面的例子中 index=[0, 0, 1] ,最后结果是将src中第0个元素和第1个元素求平均放到了位置0,然后src中第2个元素保持不变放到了位置1。

如果 index=[1, 1, 0] ,结果为:

tensor([[7.0000, 8.0000, 9.0000],
        [2.5000, 3.5000, 4.5000]])

可以发现,上述结果是将src中第2个元素 [7, 8, 9] 保持不变放到了位置0,然后将src中第0个元素 [1, 2, 3] 和第1个元素 [4, 5, 6] 求平均保持不变放到了位置1。

也就是说,无论index怎么变化,都是优先将index中0对应位置的操作结果进行放置。

2.3 维度问题

如果src的维度为(4, 3),而我们需要对 dim=0 操作,也就是一共有四个元素,那么index的长度应该为4,即以下操作是不合法的:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
index = torch.tensor([1, 1, 0], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)

报错为:

RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 0.  Target sizes: [4, 3].  Tensor sizes: [3, 1]

正确做法应该是:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
index = torch.tensor([1, 1, 0, 2], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)

输出为:

tensor([[ 7.0000,  8.0000,  9.0000],
        [ 2.5000,  3.5000,  4.5000],
        [10.0000, 11.0000, 12.0000]])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Opencv实现计算两条直线或线段角度方法详解

    Opencv实现计算两条直线或线段角度方法详解

    这篇文章主要介绍了Opencv实现计算两条直线或线段角度方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧
    2022-12-12
  • 使用Python设置PDF中图片的透明度的实现方法

    使用Python设置PDF中图片的透明度的实现方法

    在PDF文档的设计与内容创作过程中,图像的透明度设置是一个重要的操作,尤其是在处理图文密集型PDF文档时,本文将介绍如何使用Python添加指定透明度的图片到PDF文档或调整PDF文档中现有图片的透明度,需要的朋友可以参考下
    2024-09-09
  • python turtle库画一个方格和圆实例

    python turtle库画一个方格和圆实例

    在本篇文章里小编给大家分享了关于python中用turtle库画一个方格和圆实例和相关代码,需要的朋友们可以学习参考下。
    2019-06-06
  • 浅析python 定时拆分备份 nginx 日志的方法

    浅析python 定时拆分备份 nginx 日志的方法

    本文给大家分享python 定时拆分备份 nginx 日志的方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧
    2020-04-04
  • Python中的递归函数使用详解

    Python中的递归函数使用详解

    这篇文章主要介绍了Python中的递归函数使用详解,递归函数是指某个函数调用自己或者调用其他函数后再次调用自己,由于不能无限嵌套调用,所以某个递归函数一定存在至少两个分支,一个是退出嵌套,不再直接或者间接调用自己;另外一个则是继续嵌套,需要的朋友可以参考下
    2023-12-12
  • python 读写文件包含多种编码格式的解决方式

    python 读写文件包含多种编码格式的解决方式

    今天小编就为大家分享一篇python 读写文件包含多种编码格式的解决方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • python如何实现excel数据添加到mongodb

    python如何实现excel数据添加到mongodb

    本文介绍了python是如何实现excel数据添加到mongodb,为了将数据导入mongodb,引入了pymongo,xlrd包,需要的朋友可以参考下
    2015-07-07
  • torch.utils.data.DataLoader与迭代器转换操作

    torch.utils.data.DataLoader与迭代器转换操作

    这篇文章主要介绍了torch.utils.data.DataLoader与迭代器转换操作,文章内容接受非常详细,对正在学习或工作的你有一定的帮助,需要的朋友可以参考一下
    2022-02-02
  • Python搭建HTTP服务过程图解

    Python搭建HTTP服务过程图解

    这篇文章主要介绍了Python搭建HTTP服务过程图解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-12-12
  • python调用有道智云API实现文件批量翻译

    python调用有道智云API实现文件批量翻译

    这篇文章主要介绍了python如何调用有道智云API实现文件批量翻译,帮助大家更好得理解和使用python,感兴趣的朋友可以了解下
    2020-10-10

最新评论