pytorch更新tensor中指定index位置的值scatter_add_问题
使用scatter_add_更新tensor张量中指定index位置的值
例子
import torch a = torch.zeros((3, 4)) print(a) """ tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]) """ b = torch.rand((2, 4)) print(b) """ tensor([[0.6293, 0.3050, 0.9608, 0.5577], [0.3469, 0.1025, 0.8185, 0.5085]]) """ # 将a中第0行和第2行的值修改为b a = a.scatter_add_(0, torch.tensor([[0, 0, 0], [2, 2, 2]]), b) print(a) """ tensor([[0.6293, 0.3050, 0.9608, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000], [0.3469, 0.1025, 0.8185, 0.0000]]) """
torch_scatter.scatter_add、Tensor.scatter_add_ 、Tensor.scatter_、Tensor.scatter_add 、Tensor.scatter
torch_scatter.scatter_add
官方文档:
torch_scatter.scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0)
Sums all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in input for dimensions outside of dim and by the corresponding value in index for dimension dim. If multiple indices reference the same location, their contributions add.
看着挺疑惑的,自己试了一把:
src = torch.tensor([10, 20, 30, 40, 1, 2, 2, 2, 9]) index = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1, 0]) out=scatter_add(src, index) print(out)
输出结果为:tensor([ 9, 97, 10])
说白了就是:index就是out的下标,将src所有和此下标对应的值加起来,就是out的值。
例如上面的例子:index中等于1的,对应于src是【20, 30, 40, 1, 2, 2, 2】,将这些值加起来是97,于是,out[1]=97
同理:out[0]=src[8]=9 out[2]=src[0]=10
另一个函数
Tensor.scatter_add_
官方文档:
scatter_add_(self, dim, index, other):
For a 3-D tensor, :attr:`self` is updated as:: self[index[i][j][k]][j][k] += other[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] += other[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] += other[i][j][k] # if dim == 2
官方例子:
>>> x = torch.rand(2, 5) >>> x tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328], [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]]) >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328], [1.0000, 1.0427, 1.0000, 1.6782, 1.0000], [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])
以index来遍历,就比较容易看懂。self中并不是每个值都要改变的。
以上面为例
index[0][0]=0 self[index[0][0]][0]=self[0][0] =self[0][0]+ x[0][0]=1 +0.7404=1.7404 index[0][1]=1 self[index[0][1]][1]=self[1][1] =self[1][1]+ x[0][1] =1 +0.0427 =1.0427
。。。
以此类推,将index遍历一遍,就得到最终的结果
所以,self中需要改变的是index中列出的坐标,其他的是不动的。
Tensor.scatter_
scatter_(self, dim, index, src)
和Tensor.scatter_add_的区别是直接将src中的值填充到self中,不做相加
例子:
>>> x = torch.rand(2, 5) >>> x tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004], [ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]) >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004], [ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000], [ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]]) >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23) >>> z tensor([[ 0.0000, 0.0000, 1.2300, 0.0000], [ 0.0000, 0.0000, 0.0000, 1.2300]])
另外,pytorch中还有
scatter_add和scatter函数,和上面两个函数不同的是这个两个函数不改变self,会返回结果值;上面两个函数(scatter_add_和scatter_)是直接在原数据self上进行修改
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
python3+PyQt5图形项的自定义和交互 python3实现page Designer应用程序
这篇文章主要为大家详细介绍了python3+PyQt5图形项的自定义和交互,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下2018-04-04
python 通过SMSActivateAPI 获取验证码的步骤
这篇文章主要介绍了python 通过SMSActivateAPI 如何获取验证码,本文分步骤给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下2023-05-05
使用Python实现在PowerPoint中创建和定制SmartArt图形
在现代商务演示中,SmartArt 图形是一种强大的可视化工具,本文将介绍如何使用 Python 在 PowerPoint 演示文稿中创建和定制 SmartArt 图形,实现自动化的专业演示文档生成,希望对大家有所帮助2026-05-05


最新评论