pytorch之torch_scatter.scatter_max()用法

 更新时间:2023年09月11日 11:45:10   作者:A2333fun  
这篇文章主要介绍了pytorch之torch_scatter.scatter_max()用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch_scatter.scatter_max()

torch_scatter.scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None)

  • 根据index将src分组,求每一组中的最大值输出到out
  • dim是维度

from torch_scatter import scatter_max
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
'''src根据index进行分组'''
out, argmax = scatter_max(src, index, out=out)
print(out)
print(argmax)

输出

tensor([[0., 0., 4., 3., 2., 0.],
        [2., 4., 3., 0., 0., 0.]])
tensor([[-1, -1,  3,  4,  0,  1],
        [ 1,  4,  3, -1, -1, -1]])

解释

torch_scatter.scatter()使用

1. 参数

具体来讲,scatter函数的作用就是将index中相同索引对应位置的src元素进行某种方式的操作,例如 sum mean 等,然后将这些操作结果按照索引顺序进行拼接。

下面我用具体的例子来进行讲解。

2. 示例

2.1 简单示例

首先初始化src和index:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
index = torch.tensor([0, 0, 1], dtype=torch.int64)

接着使用scatter函数:

out = scatter(src, index, dim=0, reduce='mean')

我们观察 index=[0, 0, 1] ,第0个位置和第1个位置都为0,第2个位置为1。也就是说,我们需要将src中第0个元素和第1个元素求平均变成一个元素,然后第2个元素求mean也就是本身为一个元素。如果 index=[1, 0, 0] ,则意味着我们需要将src中第1个元素和第2个元素求平均变成一个元素,而第0个元素保持不变。

那么src中第几个元素到底是如何定义的呢?这就需要用到 dim 参数了。

dim=0 意味着我们需要对src的维度0进行操作:

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

即src中第0个元素为 [1, 2, 3] ,第1个元素为 [4, 5, 6] ,第2个元素为 [7, 8, 9]

而如果 dim=1 ,则第0个元素为 [1, 4, 7] ,第1个元素为 [2, 5, 8] ,第2个元素为 [3, 6, 9]

因此,如果有以下代码:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
index = torch.tensor([0, 0, 1], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')

那么我们就应该将src中的第0个元素为 [1, 2, 3] 和第1个元素为 [4, 5, 6] 求平均为 [2.5, 3.5, 4.5] ,然后第2个元素 [7, 8, 9] 保持不变,即:

tensor([[2.5000, 3.5000, 4.5000],
        [7.0000, 8.0000, 9.0000]])

2.2 顺序问题

上面的例子中 index=[0, 0, 1] ,最后结果是将src中第0个元素和第1个元素求平均放到了位置0,然后src中第2个元素保持不变放到了位置1。

如果 index=[1, 1, 0] ,结果为:

tensor([[7.0000, 8.0000, 9.0000],
        [2.5000, 3.5000, 4.5000]])

可以发现,上述结果是将src中第2个元素 [7, 8, 9] 保持不变放到了位置0,然后将src中第0个元素 [1, 2, 3] 和第1个元素 [4, 5, 6] 求平均保持不变放到了位置1。

也就是说,无论index怎么变化,都是优先将index中0对应位置的操作结果进行放置。

2.3 维度问题

如果src的维度为(4, 3),而我们需要对 dim=0 操作,也就是一共有四个元素,那么index的长度应该为4,即以下操作是不合法的:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
index = torch.tensor([1, 1, 0], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)

报错为:

RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 0.  Target sizes: [4, 3].  Tensor sizes: [3, 1]

正确做法应该是:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
index = torch.tensor([1, 1, 0, 2], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)

输出为:

tensor([[ 7.0000,  8.0000,  9.0000],
        [ 2.5000,  3.5000,  4.5000],
        [10.0000, 11.0000, 12.0000]])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 基于Python+Matplotlib实现直方图的绘制

    基于Python+Matplotlib实现直方图的绘制

    Matplotlib是Python的绘图库,它能让使用者很轻松地将数据图形化,并且提供多样化的输出格式。本文将为大家介绍如何用matplotlib绘制直方图,感兴趣的朋友可以学习一下
    2022-04-04
  • 利用Python实现数值积分的方法

    利用Python实现数值积分的方法

    这篇文章主要介绍了利用Python实现数值积分。本文主要用于对比使用Python来实现数学中积分的几种计算方式,并和真值进行对比,加深大家对积分运算实现方式的理解
    2022-02-02
  • Python实现石头剪刀布游戏

    Python实现石头剪刀布游戏

    这篇文章主要为大家详细介绍了Python实现石头剪刀布游戏,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-01-01
  • 使用python实现ANN

    使用python实现ANN

    这篇文章主要为大家详细介绍了使用python实现ANN的相关资料,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-12-12
  • Python多线程中阻塞(join)与锁(Lock)使用误区解析

    Python多线程中阻塞(join)与锁(Lock)使用误区解析

    这篇文章主要为大家详细介绍了Python多线程中阻塞join与锁Lock的使用误区,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-04-04
  • python连接mysql有哪些方法

    python连接mysql有哪些方法

    在本篇文章里小编给大家分享的是一篇关于python连接mysql的方法,有兴趣的朋友们可以学习下。
    2020-06-06
  • pycharm 实现光标快速移动到括号外或行尾的操作

    pycharm 实现光标快速移动到括号外或行尾的操作

    这篇文章主要介绍了pycharm 实现光标快速移动到括号外或行尾的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-02-02
  • python 实现 redis 数据库的操作

    python 实现 redis 数据库的操作

    这篇文章主要介绍了python 包 redis 数据库的操作教程,redis 是一个 Key-Value 数据库,下文基于python的相关资料展开对redis 数据库操作的详细介绍,需要的小伙伴可以参考一下
    2022-04-04
  • python对一个数向上取整的实例方法

    python对一个数向上取整的实例方法

    在本篇文章中小编给大家整理了关于python对一个数向上取整的实例方法,需要的朋友们可以跟着学习下。
    2020-06-06
  • pyqt5 QListWidget的用法解析

    pyqt5 QListWidget的用法解析

    这篇文章主要介绍了pyqt5 QListWidget的用法解析,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03

最新评论