pytorch更新tensor中指定index位置的值scatter_add_问题

 更新时间:2023年06月14日 09:07:37   作者:腾阳山泥若  
这篇文章主要介绍了pytorch更新tensor中指定index位置的值scatter_add_问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

使用scatter_add_更新tensor张量中指定index位置的值

例子

import torch
a = torch.zeros((3, 4))
print(a)
"""
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
"""
b = torch.rand((2, 4))
print(b)
"""
tensor([[0.6293, 0.3050, 0.9608, 0.5577],
        [0.3469, 0.1025, 0.8185, 0.5085]])
"""
# 将a中第0行和第2行的值修改为b
a = a.scatter_add_(0, torch.tensor([[0, 0, 0], [2, 2, 2]]), b)
print(a)
"""
tensor([[0.6293, 0.3050, 0.9608, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.3469, 0.1025, 0.8185, 0.0000]])
"""

torch_scatter.scatter_add、Tensor.scatter_add_ 、Tensor.scatter_、Tensor.scatter_add 、Tensor.scatter

torch_scatter.scatter_add

官方文档:

torch_scatter.scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0)

Sums all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in input for dimensions outside of dim and by the corresponding value in index for dimension dim. If multiple indices reference the same location, their contributions add.

看着挺疑惑的,自己试了一把:

src = torch.tensor([10, 20, 30, 40, 1, 2, 2, 2, 9])
index = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1, 0])
out=scatter_add(src, index)
print(out)

输出结果为:tensor([ 9, 97, 10])

说白了就是:index就是out的下标,将src所有和此下标对应的值加起来,就是out的值。

例如上面的例子:index中等于1的,对应于src是【20, 30, 40, 1, 2, 2, 2】,将这些值加起来是97,于是,out[1]=97

同理:out[0]=src[8]=9     out[2]=src[0]=10

另一个函数

Tensor.scatter_add_

官方文档:

scatter_add_(self, dim, index, other):
For a 3-D tensor, :attr:`self` is updated as::
    self[index[i][j][k]][j][k] += other[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] += other[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] += other[i][j][k]  # if dim == 2

官方例子:

            >>> x = torch.rand(2, 5)
            >>> x
            tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328],
                    [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]])
            >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
            tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328],
                    [1.0000, 1.0427, 1.0000, 1.6782, 1.0000],
                    [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])

以index来遍历,就比较容易看懂。self中并不是每个值都要改变的。

以上面为例

index[0][0]=0  self[index[0][0]][0]=self[0][0] =self[0][0]+ x[0][0]=1 +0.7404=1.7404
index[0][1]=1  self[index[0][1]][1]=self[1][1] =self[1][1]+ x[0][1] =1 +0.0427 =1.0427

。。。

以此类推,将index遍历一遍,就得到最终的结果

所以,self中需要改变的是index中列出的坐标,其他的是不动的。

Tensor.scatter_

scatter_(self, dim, index, src)

和Tensor.scatter_add_的区别是直接将src中的值填充到self中,不做相加

例子:

>>> x = torch.rand(2, 5)
            >>> x
            tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
                    [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
            >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
            tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
                    [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
                    [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])
            >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
            >>> z
            tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
                    [ 0.0000,  0.0000,  0.0000,  1.2300]])

另外,pytorch中还有

scatter_add和scatter函数,和上面两个函数不同的是这个两个函数不改变self,会返回结果值;上面两个函数(scatter_add_和scatter_)是直接在原数据self上进行修改

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 对python中执行DOS命令的3种方法总结

    对python中执行DOS命令的3种方法总结

    今天小编就为大家分享一篇对python中执行DOS命令的3种方法总结,具有很好的参考价值,希望对大家有所帮助一起。一起跟随小编过来看看吧
    2018-05-05
  • python3+PyQt5图形项的自定义和交互 python3实现page Designer应用程序

    python3+PyQt5图形项的自定义和交互 python3实现page Designer应用程序

    这篇文章主要为大家详细介绍了python3+PyQt5图形项的自定义和交互,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-04-04
  • python处理常见格式压缩包文件的全指南

    python处理常见格式压缩包文件的全指南

    这篇文章主要为大家详细介绍了如何使用python处理常见格式压缩包文件,例如7z压缩包,tar和gz压缩包,zip类压缩包和.rar文件,有需要的小伙伴可以了解下
    2025-05-05
  • python 通过SMSActivateAPI 获取验证码的步骤

    python 通过SMSActivateAPI 获取验证码的步骤

    这篇文章主要介绍了python 通过SMSActivateAPI 如何获取验证码,本文分步骤给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-05-05
  • django 链接多个数据库 并使用原生sql实现

    django 链接多个数据库 并使用原生sql实现

    这篇文章主要介绍了django 链接多个数据库 并使用原生sql实现,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • 使用Python实现在PowerPoint中创建和定制SmartArt图形

    使用Python实现在PowerPoint中创建和定制SmartArt图形

    在现代商务演示中,SmartArt 图形是一种强大的可视化工具,本文将介绍如何使用 Python 在 PowerPoint 演示文稿中创建和定制 SmartArt 图形,实现自动化的专业演示文档生成,希望对大家有所帮助
    2026-05-05
  • 详解python中的模块及包导入

    详解python中的模块及包导入

    python中的导入关键字:import 以及from import。这篇文章主要介绍了详解python中的模块及包导入,需要的朋友可以参考下
    2019-08-08
  • 利用python实现轻松抓取Google搜索数据

    利用python实现轻松抓取Google搜索数据

    从谷歌抓取数据是一个复杂且需要谨慎处理的任务,因为谷歌有非常严格的反自动化和反爬虫机制,本文将分享两个最有效且合规的Python方法,需要的可以了解下
    2025-08-08
  • python中pandas操作apply返回多列的实现

    python中pandas操作apply返回多列的实现

    本文主要介绍了python中pandas操作apply返回多列的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-08-08
  • Django中数据库迁移常用的命令小结

    Django中数据库迁移常用的命令小结

    在Django中数据库迁移用于保持数据库结构与模型定义同步,这篇文章主要介绍了Django中数据库迁移常用的命令,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2025-03-03

最新评论