pytorch更新tensor中指定index位置的值scatter_add_问题

 更新时间:2023年06月14日 09:07:37   作者:腾阳山泥若  
这篇文章主要介绍了pytorch更新tensor中指定index位置的值scatter_add_问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

使用scatter_add_更新tensor张量中指定index位置的值

例子

import torch
a = torch.zeros((3, 4))
print(a)
"""
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
"""
b = torch.rand((2, 4))
print(b)
"""
tensor([[0.6293, 0.3050, 0.9608, 0.5577],
        [0.3469, 0.1025, 0.8185, 0.5085]])
"""
# 将a中第0行和第2行的值修改为b
a = a.scatter_add_(0, torch.tensor([[0, 0, 0], [2, 2, 2]]), b)
print(a)
"""
tensor([[0.6293, 0.3050, 0.9608, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.3469, 0.1025, 0.8185, 0.0000]])
"""

torch_scatter.scatter_add、Tensor.scatter_add_ 、Tensor.scatter_、Tensor.scatter_add 、Tensor.scatter

torch_scatter.scatter_add

官方文档:

torch_scatter.scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0)

Sums all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in input for dimensions outside of dim and by the corresponding value in index for dimension dim. If multiple indices reference the same location, their contributions add.

看着挺疑惑的,自己试了一把:

src = torch.tensor([10, 20, 30, 40, 1, 2, 2, 2, 9])
index = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1, 0])
out=scatter_add(src, index)
print(out)

输出结果为:tensor([ 9, 97, 10])

说白了就是:index就是out的下标,将src所有和此下标对应的值加起来,就是out的值。

例如上面的例子:index中等于1的,对应于src是【20, 30, 40, 1, 2, 2, 2】,将这些值加起来是97,于是,out[1]=97

同理:out[0]=src[8]=9     out[2]=src[0]=10

另一个函数

Tensor.scatter_add_

官方文档:

scatter_add_(self, dim, index, other):
For a 3-D tensor, :attr:`self` is updated as::
    self[index[i][j][k]][j][k] += other[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] += other[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] += other[i][j][k]  # if dim == 2

官方例子:

            >>> x = torch.rand(2, 5)
            >>> x
            tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328],
                    [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]])
            >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
            tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328],
                    [1.0000, 1.0427, 1.0000, 1.6782, 1.0000],
                    [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])

以index来遍历,就比较容易看懂。self中并不是每个值都要改变的。

以上面为例

index[0][0]=0  self[index[0][0]][0]=self[0][0] =self[0][0]+ x[0][0]=1 +0.7404=1.7404
index[0][1]=1  self[index[0][1]][1]=self[1][1] =self[1][1]+ x[0][1] =1 +0.0427 =1.0427

。。。

以此类推,将index遍历一遍,就得到最终的结果

所以,self中需要改变的是index中列出的坐标,其他的是不动的。

Tensor.scatter_

scatter_(self, dim, index, src)

和Tensor.scatter_add_的区别是直接将src中的值填充到self中,不做相加

例子:

>>> x = torch.rand(2, 5)
            >>> x
            tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
                    [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
            >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
            tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
                    [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
                    [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])
            >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
            >>> z
            tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
                    [ 0.0000,  0.0000,  0.0000,  1.2300]])

另外,pytorch中还有

scatter_add和scatter函数,和上面两个函数不同的是这个两个函数不改变self,会返回结果值;上面两个函数(scatter_add_和scatter_)是直接在原数据self上进行修改

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python k-近邻算法实例分享

    python k-近邻算法实例分享

    这个算法主要工作是测量不同特征值之间的距离,有个这个距离,就可以进行分类了。简称kNN。
    2014-06-06
  • Python字符串模糊匹配工具TheFuzz的用法详解

    Python字符串模糊匹配工具TheFuzz的用法详解

    在处理文本数据时,常常需要进行模糊字符串匹配来找到相似的字符串,Python的TheFuzz库提供了强大的方法用于解决这类问题,本文将深入介绍TheFuzz库,探讨其基本概念、常用方法和示例代码,需要的朋友可以参考下
    2023-12-12
  • 利用Pyhton中的requests包进行网页访问测试的方法

    利用Pyhton中的requests包进行网页访问测试的方法

    今天小编就为大家分享一篇利用Pyhton中的requests包进行网页访问测试的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • python学习笔记之调用eval函数出现invalid syntax错误问题

    python学习笔记之调用eval函数出现invalid syntax错误问题

    python是一门多种用途的编程语言,时常扮演脚本语言的角色。一般来说,python可以定义为面向对象的脚本语言,这个定义把面向对象的支持和面向脚本语言的角色融合在一起。很多时候,人们常常喜欢用“脚本”和不是语言来描述python的代码文件。
    2015-10-10
  • python遍历数组的方法小结

    python遍历数组的方法小结

    这篇文章主要介绍了python遍历数组的方法,实例总结了两种Python遍历数组的技巧,非常具有实用价值,需要的朋友可以参考下
    2015-04-04
  • Django中多种重定向方法使用详解

    Django中多种重定向方法使用详解

    这篇文章主要介绍了Django中多种重定向方法使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • PyTorch 使用torchvision进行图片数据增广

    PyTorch 使用torchvision进行图片数据增广

    本文主要介绍了PyTorch 使用torchvision进行图片数据增广,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-05-05
  • python 发送qq邮件的示例

    python 发送qq邮件的示例

    这篇文章主要介绍了python 发送qq邮件的示例,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下
    2021-03-03
  • Python类的基本写法与注释风格介绍

    Python类的基本写法与注释风格介绍

    这篇文章主要介绍了Python类的基本写法与注释风格,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-06-06
  • pytorch简单实现神经网络功能

    pytorch简单实现神经网络功能

    这篇文章主要介绍了pytorch简单实现神经网络,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-09-09

最新评论