pytorch tensor按广播赋值scatter_函数的用法

 更新时间:2023年06月14日 08:44:50   作者:城俊BLOG  
这篇文章主要介绍了pytorch tensor按广播赋值scatter_函数的用法,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch tensor按广播赋值scatter函数

普通广播

>>> import torch
>>> a = torch.tensor([[1,2,3],[4,5,6]])
# 和a shape相同,但是用0填充
>>> b = torch.full_like(a,0)
>>> c = torch.tensor([[0,0,1],[1,0,1]])
# 赋值索引
>>> c[:,0]
tensor([0, 1])
# 赋值语句:使用广播机制进行赋值
>>> b[range(n),c[:,0]] = 1
>>> b
tensor([[1, 0, 0],
        [0, 1, 0]])

为什么会出现这样的结果?

赋值语句的意思是:

  • 1.range(n)表示对b的所有行进行赋值操作
  • 2.c[:,0]] 表示执行赋值操作的b的列索引,[0, 1] 表示第一行对索引为0的列进行操作(赋值为1);第二行对索引为1的列进行操作(赋值为1)
  • 3.最右边的1表示对应索引位置所赋的值

scatter函数

import torch
label = torch.zeros(3, 6) #首先生成一个全零的多维数组
print("label:",label)
a = torch.ones(3,5)
b = [[0,1,2],[0,1,3],[1,2,3]]
#这里需要解释的是,b的行数要小于等于label的行数,列数要小于等于a的列数
print(a)
label.scatter_(1,torch.LongTensor(b),a) 
#参数解释:‘1':需要赋值的维度,是label的维度;‘torch.LongTensor(b)':需要赋值的索引;‘a':要赋的值
print("new_label: ",label)
label: 
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
new_label:  
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 0., 1., 0., 0.],
        [0., 1., 1., 1., 0., 0.]])

举例

>>> b = torch.full_like(a,0)
>>> b
tensor([[0, 0, 0],
        [0, 0, 0]])
>>> c = torch.tensor([[0,0],[1,0]])
>>> c
tensor([[0, 0],
        [1, 0]])
# 1表示对b的列进行赋值,以c的每一行的值作为b的列索引,一行一行地进行赋值
# c第一行 [0,0] 表示分别将b的 第一行 第0列、第0列 元素赋值为1 (重复操作了)
# c第二行 [1,0] 表示 将b的 第1列、第0列 元素赋值为1 (逆序了)
# 上面的这两个赋值操作其实有重复的、逆序的
>>> b.scatter_(1,torch.LongTensor(c),1)
>>> b
tensor([[1, 0, 0],
        [1, 1, 0]])

scatter()和scatter_()的作用和区别

scatter和scatter_函数原型如下

Tensor.scatter_(dim, index, src, reduce=None)->Tensor
scatter(input, dim, index, src)->Tensor

函数作用是将src中的数据按照dim中指定的维度和index中的索引写入self中。

  • dim(int) - 操作的维度
  • index(LongTensor) - 填充依据的索引,
  • src(Tensor of float) - 操作的src数据
  • reduce(str, optional) - reduce选择运算方式,有’add’和’mutiply’方式, 默认为替换 dim(int)

在scatter中self指返回的tensor,scatter_中self指输入的tensor自身。

对于一个三维张量,self更新结果如下

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

使用示例

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

dim=0, 说明按照行赋值,index[0][1]=1, 代表更改input中的第1行,src[0][1]=2,因此更改input中[1][1]中的元素为2

>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

dim,说明按照列赋值,index[0][1]=1, 代表更改input中的第1列,src[0][1]=2, 更改input中[0][1]元素为2

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])

scatter的应用, one-hot编码

def one_hot(x, n_class, dtype=torch.float32):
    # X shape: (batch), output shape: (batch, n_class)
    x=x.long()
    res=torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) # shape为[batch, n_class]全零向量
    res.scatter_(1, x.view(-1,1), 1) 
    # scatter_(input, dim, index, src)将src中数据根据index的索引按照dim的方向填进input中
    return res
x=torch.tensor([5,7,0])
one_hot(x, 10)
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Django使用消息提示简单的弹出个对话框实例

    Django使用消息提示简单的弹出个对话框实例

    今天小编就为大家分享一篇Django使用消息提示简单的弹出个对话框实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • Python 中数组和数字相乘时的注意事项说明

    Python 中数组和数字相乘时的注意事项说明

    这篇文章主要介绍了Python 中数组和数字相乘时的注意事项说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-05-05
  • 20个被低估的Python性能优化技巧分享

    20个被低估的Python性能优化技巧分享

    这篇文章主要为大家详细介绍了20个被低估的Python性能优化技巧并附上了实测数据,文中的示例代码简洁易懂,有需要的小伙伴可以参考一下
    2025-03-03
  • Python word2vec训练词向量实例分析讲解

    Python word2vec训练词向量实例分析讲解

    这篇文章主要介绍了Python word2vec训练词向量实例分析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧
    2022-12-12
  • 关于np.meshgrid函数中的indexing参数问题

    关于np.meshgrid函数中的indexing参数问题

    Meshgrid函数在二维与三维空间中用于生成坐标网格,便于进行图像处理和空间数据分析,二维情况下,默认使用笛卡尔坐标系,而三维meshgrid则涉及不同的坐标轴取法,在三维情况下,可能会出现坐标轴排列序混乱
    2024-09-09
  • python基础知识(一)变量与简单数据类型详解

    python基础知识(一)变量与简单数据类型详解

    这篇文章主要介绍了python变量与简单数据类型详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-04-04
  • python小练习之爬鱿鱼游戏的评价生成词云

    python小练习之爬鱿鱼游戏的评价生成词云

    读万卷书不如行万里路,只学书上的理论是远远不够的,只有在实战中才能获得能力的提升,本篇文章手把手带你用Python爬取热火的鱿鱼游戏评价,大家可以在过程中查缺补漏,提升水平
    2021-10-10
  • 详解如何使用pip卸载所有已安装的Python包

    详解如何使用pip卸载所有已安装的Python包

    在开发过程中,我们可能会安装许多 Python 包,有时需要彻底清理环境,以便从头开始或者解决冲突问题,下面将介绍如何使用 pip 命令卸载所有已安装的 Python 包,需要的朋友可以参考下
    2024-06-06
  • 对Matlab中共轭、转置和共轭装置的区别说明

    对Matlab中共轭、转置和共轭装置的区别说明

    这篇文章主要介绍了对Matlab中共轭、转置和共轭装置的区别说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • 使用Python画了一棵圣诞树的实例代码

    使用Python画了一棵圣诞树的实例代码

    这篇文章主要介绍了使用Python画了一棵圣诞树的实例代码,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-11-11

最新评论