pytorch中的scatter_add_函数的使用解读

 更新时间:2023年06月14日 08:59:37   作者:*Lisen  
这篇文章主要介绍了pytorch中的scatter_add_函数的使用解读,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch scatter_add_函数的使用

关于这个函数,很少博客详细的介绍。下面就我个人理解简单介绍下。

函数:

self_tensor.scatter_add_(dim, index_tensor, other_tensor) → 输出tensor

该函数的意思是:

将other_tensor中的数据,按照index_tensor中的索引位置,添加至self_tensor矩阵中。

参数:

  • dim:表示需要改变的维度,但是注意,假如dim=1,并不是说self_tensor在dim=0上的数据不会改变,这个dim只是在取矩阵数据时不固定dim=1维的索引,使用index_tensor矩阵中的索引。可能这样说还是不太理解,下面会用例子说明。其中self_tensor表示我们需要改变的tensor矩阵
  • index_tensor:索引矩阵;
  • other_tensor:需要添加到self_tensor中的tensor

要求:

1、self_tensor,index_tensor, other_tensor 的维度需要相同,即self.tensor.dim() = index_tensor.dim() = other_tensor.dim();

2、假设dim=d,那么index_tensor矩阵中的所有数据必须小于d-1;

3、假设dim=d,index_tensor矩阵在d维度上的size必须小于self_tensor和other_tensor的size;即index.size(d) <= other_tensor.size(d) 且index.size(d) <= self_tensor.size(d)

三维计算公式:

self[index[i][j][k]][j][k] += other[i][j][k] # 如果 dim == 0
self[i][index[i][j][k]][k] += other[i][j][k] # 如果 dim == 1
self[i][j][index[i][j][k]] += other[i][j][k] # 如果 dim == 2

二维计算公式:

self[index[i][j]][j] += other[i][j] # 如果 dim == 0
self[i][index[i][j]] += other[i][j] # 如果 dim == 1
  index_tensor = torch.tensor([[0,1],[1,1]])
  print('index_tensor: \n', index_tensor)
  self_tensor = torch.arange(0, 4).view(2, 2)
  print('self_tensor: \n', self_tensor)
  other_tensor = torch.arange(5, 9).view(2, 2)
  print('other_tensor: \n', other_tensor)
  dim = 0
  for i in range(index_tensor.size(0)):
      for j in range(index_tensor.size(1)):
          replace_index = index_tensor[i][j]
          if dim == 0:
              # self矩阵的第0维索引
              self_tensor[replace_index][j] += other_tensor[i][j]
          elif dim == 1:
              # self矩阵的第1维索引
              self_tensor[i][replace_index] += other_tensor[i][j]       
  print(self_tensor)

结果:

    index_tensor: 
 tensor([[0, 1],
        [1, 1]])
self_tensor: 
 tensor([[0, 1],
        [2, 3]])
other_tensor: 
 tensor([[5, 6],
        [7, 8]])
tensor([[ 5,  1],
        [ 9, 17]])

使用函数计算:

index_tensor = torch.tensor([[0,1],[1,1]])
print('index_tensor: \n', index_tensor)
self_tensor = torch.arange(0, 4).view(2, 2)
print('self_tensor: \n', self_tensor)
other_tensor = torch.arange(5, 9).view(2, 2)
print('other_tensor: \n', other_tensor)
self_tensor.scatter_add_(0, index_tensor, other_tensor) 
print(self_tensor)

结果:

index_tensor: 
 tensor([[0, 1],
        [1, 1]])
self_tensor: 
 tensor([[0, 1],
        [2, 3]])
other_tensor: 
 tensor([[5, 6],
        [7, 8]])
tensor([[ 5,  1],
        [ 9, 17]])

scatter_add()函数通俗理解

self [ index[i,j] , j ] += src [ i , j ] # if dim == 0
self [ i , index[i,j] ] += src[ i, j ] # if dim == 1

理解scatter_add()函数,看index就行了,index有多少个,self坐标就会变多少次。

self是一个二维的数组,self[第一维,第二维],dim==0,就是将src对应坐标,对应到 index 坐标里面的值,放置到self的第一维中。

例如:

src[i,j]对应到index[i,j],假设 index[i,j] ==0,则self[第一维,第二维] 为self[0,j],只改变第一维,第二维的值和src第二维一样。

然后self[0,j]的值就会变为 self[0,j]=self[0,j]+src[i,j]

代码中 self=torch.zeros(3,5), dim=0, index=[0,1,2,0,0], src=torch.ones(2,5)

我们只看 src,当 src[0,0]=1, index[0,0]=0, self[0,0]=self[0,0]+src[0,0]=1。

当src[0,1]=1, index[0,1]=1, self[1,1]=self[1,1]+src[0,1]=1, self的第一维是index的值决定的为1,第二维是src的第二维坐标决定也为1。

当index的值没有时,就停止变换,self没有变换过的坐标值就保持不变。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python3.10中match-case的用法和示例详解

    Python3.10中match-case的用法和示例详解

    在 Python 3.10 中引入了新的 match-case 语法,它是一种用于模式匹配的结构,下面小编就来和大家简单聊聊match-case的用法和示例吧,有需要的小伙伴可以参考下
    2023-10-10
  • python numpy库np.percentile用法说明

    python numpy库np.percentile用法说明

    这篇文章主要介绍了python numpy库np.percentile用法说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • OpenCV-Python实现轮廓拟合

    OpenCV-Python实现轮廓拟合

    本文将结合实例代码,介绍OpenCV-Python实现轮廓拟合,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-06-06
  • 基于pytorch padding=SAME的解决方式

    基于pytorch padding=SAME的解决方式

    今天小编就为大家分享一篇基于pytorch padding=SAME的解决方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • 基于Python实现简易文档格式转换器

    基于Python实现简易文档格式转换器

    这篇文章主要介绍了基于Python和PyQT5实现简易的文档格式转换器,支持.txt/.xlsx/.csv格式的转换。感兴趣的小伙伴可以跟随小编一起学习一下
    2021-12-12
  • Python3 shelve对象持久存储原理详解

    Python3 shelve对象持久存储原理详解

    这篇文章主要介绍了Python3 shelve对象持久存储原理详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-03-03
  • opencv实现图像缩放效果

    opencv实现图像缩放效果

    这篇文章主要为大家详细介绍了opencv实现图像缩放效果,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-03-03
  • Python对象与json数据的转换问题实例详解

    Python对象与json数据的转换问题实例详解

    JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式,很受广大用户喜爱,今天通过本文给大家介绍Python对象与json数据的转换问题,需要的朋友可以参考下
    2022-07-07
  • python使用opencv实现马赛克效果示例

    python使用opencv实现马赛克效果示例

    这篇文章主要介绍了python使用opencv实现马赛克效果,结合实例形式分析了Python使用cv2模块操作图片实现马赛克效果的相关技巧,需要的朋友可以参考下
    2019-09-09
  • WIndows10系统下面安装Anaconda、Pycharm及Pytorch环境全过程(NVIDIA GPU版本)

    WIndows10系统下面安装Anaconda、Pycharm及Pytorch环境全过程(NVIDIA GPU版本)

    这篇文章主要给大家介绍了关于WIndows10系统下面安装Anaconda、Pycharm及Pytorch环境(NVIDIA GPU版本)的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2023-02-02

最新评论