pytorch tensor按广播赋值scatter_函数的用法

 更新时间:2023年06月14日 08:44:50   作者:城俊BLOG  
这篇文章主要介绍了pytorch tensor按广播赋值scatter_函数的用法,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch tensor按广播赋值scatter函数

普通广播

>>> import torch
>>> a = torch.tensor([[1,2,3],[4,5,6]])
# 和a shape相同,但是用0填充
>>> b = torch.full_like(a,0)
>>> c = torch.tensor([[0,0,1],[1,0,1]])
# 赋值索引
>>> c[:,0]
tensor([0, 1])
# 赋值语句:使用广播机制进行赋值
>>> b[range(n),c[:,0]] = 1
>>> b
tensor([[1, 0, 0],
        [0, 1, 0]])

为什么会出现这样的结果?

赋值语句的意思是:

  • 1.range(n)表示对b的所有行进行赋值操作
  • 2.c[:,0]] 表示执行赋值操作的b的列索引,[0, 1] 表示第一行对索引为0的列进行操作(赋值为1);第二行对索引为1的列进行操作(赋值为1)
  • 3.最右边的1表示对应索引位置所赋的值

scatter函数

import torch
label = torch.zeros(3, 6) #首先生成一个全零的多维数组
print("label:",label)
a = torch.ones(3,5)
b = [[0,1,2],[0,1,3],[1,2,3]]
#这里需要解释的是,b的行数要小于等于label的行数,列数要小于等于a的列数
print(a)
label.scatter_(1,torch.LongTensor(b),a) 
#参数解释:‘1':需要赋值的维度,是label的维度;‘torch.LongTensor(b)':需要赋值的索引;‘a':要赋的值
print("new_label: ",label)
label: 
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
new_label:  
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 0., 1., 0., 0.],
        [0., 1., 1., 1., 0., 0.]])

举例

>>> b = torch.full_like(a,0)
>>> b
tensor([[0, 0, 0],
        [0, 0, 0]])
>>> c = torch.tensor([[0,0],[1,0]])
>>> c
tensor([[0, 0],
        [1, 0]])
# 1表示对b的列进行赋值,以c的每一行的值作为b的列索引,一行一行地进行赋值
# c第一行 [0,0] 表示分别将b的 第一行 第0列、第0列 元素赋值为1 (重复操作了)
# c第二行 [1,0] 表示 将b的 第1列、第0列 元素赋值为1 (逆序了)
# 上面的这两个赋值操作其实有重复的、逆序的
>>> b.scatter_(1,torch.LongTensor(c),1)
>>> b
tensor([[1, 0, 0],
        [1, 1, 0]])

scatter()和scatter_()的作用和区别

scatter和scatter_函数原型如下

Tensor.scatter_(dim, index, src, reduce=None)->Tensor
scatter(input, dim, index, src)->Tensor

函数作用是将src中的数据按照dim中指定的维度和index中的索引写入self中。

  • dim(int) - 操作的维度
  • index(LongTensor) - 填充依据的索引,
  • src(Tensor of float) - 操作的src数据
  • reduce(str, optional) - reduce选择运算方式,有’add’和’mutiply’方式, 默认为替换 dim(int)

在scatter中self指返回的tensor,scatter_中self指输入的tensor自身。

对于一个三维张量,self更新结果如下

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

使用示例

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

dim=0, 说明按照行赋值,index[0][1]=1, 代表更改input中的第1行,src[0][1]=2,因此更改input中[1][1]中的元素为2

>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

dim,说明按照列赋值,index[0][1]=1, 代表更改input中的第1列,src[0][1]=2, 更改input中[0][1]元素为2

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])

scatter的应用, one-hot编码

def one_hot(x, n_class, dtype=torch.float32):
    # X shape: (batch), output shape: (batch, n_class)
    x=x.long()
    res=torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) # shape为[batch, n_class]全零向量
    res.scatter_(1, x.view(-1,1), 1) 
    # scatter_(input, dim, index, src)将src中数据根据index的索引按照dim的方向填进input中
    return res
x=torch.tensor([5,7,0])
one_hot(x, 10)
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Scrapy使用的基本流程与实例讲解

    Scrapy使用的基本流程与实例讲解

    今天小编就为大家分享一篇关于Scrapy使用的基本流程与实例讲解,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2018-10-10
  • 利用Python实现在同一网络中的本地文件共享方法

    利用Python实现在同一网络中的本地文件共享方法

    今天小编就为大家分享一篇利用Python实现在同一网络中的本地文件共享方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • python多线程和多进程关系详解

    python多线程和多进程关系详解

    在本篇文章里小编给大家整理的是一篇关于python多线程和多进程之间的联系的基础内容,有兴趣的朋友们可以学习下。
    2020-12-12
  • Python中Class类用法实例分析

    Python中Class类用法实例分析

    这篇文章主要介绍了Python中Class类用法,以实例形式较为详细的分析了Python中类的定义及相关使用技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-11-11
  • 在Python中用GDAL实现矢量对栅格的切割实例

    在Python中用GDAL实现矢量对栅格的切割实例

    这篇文章主要介绍了在Python中用GDAL实现矢量对栅格的切割实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • 在Python的Flask框架中实现全文搜索功能

    在Python的Flask框架中实现全文搜索功能

    这篇文章主要介绍了在Python的Flask框架中实现全文搜索功能,这个基本的web功能实现起来非常简单,需要的朋友可以参考下
    2015-04-04
  • Python正则表达式总结分享

    Python正则表达式总结分享

    这篇文章主要介绍了Python正则表达式总结分享,包括正则表达式基础以及Python正则表达式标准库的完整介绍及使用示例,需要的朋友可以参考一下
    2022-03-03
  • python  UPX is not available问题解决方法

    python  UPX is not available问题解决方法

    这篇文章主要介绍了python UPX is not available问题解决,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-04-04
  • 详解pandas绘制矩阵散点图(scatter_matrix)的方法

    详解pandas绘制矩阵散点图(scatter_matrix)的方法

    这篇文章主要介绍了详解pandas绘制矩阵散点图(scatter_matrix)的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-04-04
  • python控制台显示时钟的示例

    python控制台显示时钟的示例

    这篇文章主要介绍了python控制台显示时钟的示例,需要的朋友可以参考下
    2014-02-02

最新评论