pytorch更新tensor中指定index位置的值scatter_add_问题

 更新时间:2023年06月14日 09:07:37   作者:腾阳山泥若  
这篇文章主要介绍了pytorch更新tensor中指定index位置的值scatter_add_问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

使用scatter_add_更新tensor张量中指定index位置的值

例子

import torch
a = torch.zeros((3, 4))
print(a)
"""
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
"""
b = torch.rand((2, 4))
print(b)
"""
tensor([[0.6293, 0.3050, 0.9608, 0.5577],
        [0.3469, 0.1025, 0.8185, 0.5085]])
"""
# 将a中第0行和第2行的值修改为b
a = a.scatter_add_(0, torch.tensor([[0, 0, 0], [2, 2, 2]]), b)
print(a)
"""
tensor([[0.6293, 0.3050, 0.9608, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.3469, 0.1025, 0.8185, 0.0000]])
"""

torch_scatter.scatter_add、Tensor.scatter_add_ 、Tensor.scatter_、Tensor.scatter_add 、Tensor.scatter

torch_scatter.scatter_add

官方文档:

torch_scatter.scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0)

Sums all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in input for dimensions outside of dim and by the corresponding value in index for dimension dim. If multiple indices reference the same location, their contributions add.

看着挺疑惑的,自己试了一把:

src = torch.tensor([10, 20, 30, 40, 1, 2, 2, 2, 9])
index = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1, 0])
out=scatter_add(src, index)
print(out)

输出结果为:tensor([ 9, 97, 10])

说白了就是:index就是out的下标,将src所有和此下标对应的值加起来,就是out的值。

例如上面的例子:index中等于1的,对应于src是【20, 30, 40, 1, 2, 2, 2】,将这些值加起来是97,于是,out[1]=97

同理:out[0]=src[8]=9     out[2]=src[0]=10

另一个函数

Tensor.scatter_add_

官方文档:

scatter_add_(self, dim, index, other):
For a 3-D tensor, :attr:`self` is updated as::
    self[index[i][j][k]][j][k] += other[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] += other[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] += other[i][j][k]  # if dim == 2

官方例子:

            >>> x = torch.rand(2, 5)
            >>> x
            tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328],
                    [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]])
            >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
            tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328],
                    [1.0000, 1.0427, 1.0000, 1.6782, 1.0000],
                    [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])

以index来遍历,就比较容易看懂。self中并不是每个值都要改变的。

以上面为例

index[0][0]=0  self[index[0][0]][0]=self[0][0] =self[0][0]+ x[0][0]=1 +0.7404=1.7404
index[0][1]=1  self[index[0][1]][1]=self[1][1] =self[1][1]+ x[0][1] =1 +0.0427 =1.0427

。。。

以此类推,将index遍历一遍,就得到最终的结果

所以,self中需要改变的是index中列出的坐标,其他的是不动的。

Tensor.scatter_

scatter_(self, dim, index, src)

和Tensor.scatter_add_的区别是直接将src中的值填充到self中,不做相加

例子:

>>> x = torch.rand(2, 5)
            >>> x
            tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
                    [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
            >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
            tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
                    [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
                    [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])
            >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
            >>> z
            tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
                    [ 0.0000,  0.0000,  0.0000,  1.2300]])

另外,pytorch中还有

scatter_add和scatter函数,和上面两个函数不同的是这个两个函数不改变self,会返回结果值;上面两个函数(scatter_add_和scatter_)是直接在原数据self上进行修改

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python迭代器的实现原理

    Python迭代器的实现原理

    这篇文章主要介绍了Python迭代器的实现原理,文章基于python的相关资料展开对Python迭代器的详细介绍,需要的小伙伴可以参考一下
    2022-05-05
  • python 动态迁移solr数据过程解析

    python 动态迁移solr数据过程解析

    这篇文章主要介绍了python 动态迁移solr数据过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • 利用OpenCV和Python实现查找图片差异

    利用OpenCV和Python实现查找图片差异

    今天小编就为大家分享一篇利用OpenCV和Python实现查找图片差异,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • Python gRPC流式通信协议详细讲解

    Python gRPC流式通信协议详细讲解

    这篇文章主要介绍了Python gRPC流式通信协议,最近几天在搞golang的grpc,跑通之后想用php作为客户端调用一下grpc服务,结果拉了,一个php的grpc服务安装,搞了好几天,总算搞定了
    2022-11-11
  • Pandas条件筛选与组合筛选的使用

    Pandas条件筛选与组合筛选的使用

    本文主要介绍了Pandas条件筛选与组合筛选的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-01-01
  • python pygame实现五子棋双人联机

    python pygame实现五子棋双人联机

    这篇文章主要为大家详细介绍了python pygame实现五子棋双人联机,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-05-05
  • Pytorch GPU内存占用很高,但是利用率很低如何解决

    Pytorch GPU内存占用很高,但是利用率很低如何解决

    这篇文章主要介绍了Pytorch GPU内存占用很高,但是利用率很低的原因及解决方法,具有很好的参考价值,希望对大家 有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-06-06
  • 如何在Python中编写接口和请求外部接口

    如何在Python中编写接口和请求外部接口

    这篇文章主要介绍了如何在Python中编写接口和请求外部接口,requests库来请求外部接口,按照请求方法分为get请求和post请求,下面和小编一起进入文章了解更多的具体内容吧
    2022-02-02
  • Numpy中stack(),hstack(),vstack()函数用法介绍及实例

    Numpy中stack(),hstack(),vstack()函数用法介绍及实例

    这篇文章主要介绍了Numpy中stack(),hstack(),vstack()函数用法介绍及实例,具有一定借鉴价值,需要的朋友可以参考下
    2018-01-01
  • 关于matlab图像滤波详解(二维傅里叶滤波)

    关于matlab图像滤波详解(二维傅里叶滤波)

    这篇文章主要介绍了关于matlab图像滤波详解(二维傅里叶滤波),具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-02-02

最新评论