pytorch tensor按广播赋值scatter_函数的用法

 更新时间:2023年06月14日 08:44:50   作者:城俊BLOG  
这篇文章主要介绍了pytorch tensor按广播赋值scatter_函数的用法,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch tensor按广播赋值scatter函数

普通广播

>>> import torch
>>> a = torch.tensor([[1,2,3],[4,5,6]])
# 和a shape相同,但是用0填充
>>> b = torch.full_like(a,0)
>>> c = torch.tensor([[0,0,1],[1,0,1]])
# 赋值索引
>>> c[:,0]
tensor([0, 1])
# 赋值语句:使用广播机制进行赋值
>>> b[range(n),c[:,0]] = 1
>>> b
tensor([[1, 0, 0],
        [0, 1, 0]])

为什么会出现这样的结果?

赋值语句的意思是:

  • 1.range(n)表示对b的所有行进行赋值操作
  • 2.c[:,0]] 表示执行赋值操作的b的列索引,[0, 1] 表示第一行对索引为0的列进行操作(赋值为1);第二行对索引为1的列进行操作(赋值为1)
  • 3.最右边的1表示对应索引位置所赋的值

scatter函数

import torch
label = torch.zeros(3, 6) #首先生成一个全零的多维数组
print("label:",label)
a = torch.ones(3,5)
b = [[0,1,2],[0,1,3],[1,2,3]]
#这里需要解释的是,b的行数要小于等于label的行数,列数要小于等于a的列数
print(a)
label.scatter_(1,torch.LongTensor(b),a) 
#参数解释:‘1':需要赋值的维度,是label的维度;‘torch.LongTensor(b)':需要赋值的索引;‘a':要赋的值
print("new_label: ",label)
label: 
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
new_label:  
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 0., 1., 0., 0.],
        [0., 1., 1., 1., 0., 0.]])

举例

>>> b = torch.full_like(a,0)
>>> b
tensor([[0, 0, 0],
        [0, 0, 0]])
>>> c = torch.tensor([[0,0],[1,0]])
>>> c
tensor([[0, 0],
        [1, 0]])
# 1表示对b的列进行赋值,以c的每一行的值作为b的列索引,一行一行地进行赋值
# c第一行 [0,0] 表示分别将b的 第一行 第0列、第0列 元素赋值为1 (重复操作了)
# c第二行 [1,0] 表示 将b的 第1列、第0列 元素赋值为1 (逆序了)
# 上面的这两个赋值操作其实有重复的、逆序的
>>> b.scatter_(1,torch.LongTensor(c),1)
>>> b
tensor([[1, 0, 0],
        [1, 1, 0]])

scatter()和scatter_()的作用和区别

scatter和scatter_函数原型如下

Tensor.scatter_(dim, index, src, reduce=None)->Tensor
scatter(input, dim, index, src)->Tensor

函数作用是将src中的数据按照dim中指定的维度和index中的索引写入self中。

  • dim(int) - 操作的维度
  • index(LongTensor) - 填充依据的索引,
  • src(Tensor of float) - 操作的src数据
  • reduce(str, optional) - reduce选择运算方式,有’add’和’mutiply’方式, 默认为替换 dim(int)

在scatter中self指返回的tensor,scatter_中self指输入的tensor自身。

对于一个三维张量,self更新结果如下

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

使用示例

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

dim=0, 说明按照行赋值,index[0][1]=1, 代表更改input中的第1行,src[0][1]=2,因此更改input中[1][1]中的元素为2

>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

dim,说明按照列赋值,index[0][1]=1, 代表更改input中的第1列,src[0][1]=2, 更改input中[0][1]元素为2

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])

scatter的应用, one-hot编码

def one_hot(x, n_class, dtype=torch.float32):
    # X shape: (batch), output shape: (batch, n_class)
    x=x.long()
    res=torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) # shape为[batch, n_class]全零向量
    res.scatter_(1, x.view(-1,1), 1) 
    # scatter_(input, dim, index, src)将src中数据根据index的索引按照dim的方向填进input中
    return res
x=torch.tensor([5,7,0])
one_hot(x, 10)
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python中Timedelta转换为Int或Float方式

    Python中Timedelta转换为Int或Float方式

    这篇文章主要介绍了Python中Timedelta转换为Int或Float方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-07-07
  • python调用C/C++动态库的实践案例

    python调用C/C++动态库的实践案例

    python是动态语言,c++是静态语言,下面这篇文章主要介绍了python调用C/C++动态库的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2025-09-09
  • 机器学习Erdos Renyi随机图生成方法及特性

    机器学习Erdos Renyi随机图生成方法及特性

    这篇文章主要为大家介绍了机器学习Erdos Renyi随机图生成方法及特性详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-05-05
  • Python脚本如何在bilibili中查找弹幕发送者

    Python脚本如何在bilibili中查找弹幕发送者

    这篇文章主要介绍了如何在bilibili中查找弹幕发送者,本文给大家分享小编写的一个python脚本来实现bilibili弹幕发送者,需要的朋友可以参考下
    2020-06-06
  • 彻底理解Python list切片原理

    彻底理解Python list切片原理

    本篇文章主要介绍了Python list切片原理,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-10-10
  • 详解Python 数据库 (sqlite3)应用

    详解Python 数据库 (sqlite3)应用

    本篇文章主要介绍了Python标准库14 数据库 (sqlite3),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧。
    2016-12-12
  • 深入讲解Python函数中参数的使用及默认参数的陷阱

    深入讲解Python函数中参数的使用及默认参数的陷阱

    这篇文章主要介绍了Python函数中参数的使用及默认参数的陷阱,文中将函数的参数分为必选参数、默认参数、可变参数和关键字参数来讲,要的朋友可以参考下
    2016-03-03
  • Python中模拟enum枚举类型的5种方法分享

    Python中模拟enum枚举类型的5种方法分享

    这篇文章主要介绍了Python中模拟enum枚举类型的5种方法分享,本文直接给出实现代码,需要的朋友可以参考下
    2014-11-11
  • python爬虫爬取监控教务系统的思路详解

    python爬虫爬取监控教务系统的思路详解

    这篇文章主要介绍了python爬虫监控教务系统,主要实现思路是对已有的成绩进行处理,变为list集合,本文通过实例代码给大家介绍的非常详细,需要的朋友可以参考下
    2020-01-01
  • Python使用OpenCV实现物体跟踪的实践方法

    Python使用OpenCV实现物体跟踪的实践方法

    文章介绍了使用OpenCV进行物体跟踪的小项目,涵盖选择ROI、初始化跟踪器、逐帧更新等核心流程,代码示例展示了不同OpenCV版本的兼容性处理、跟踪失败提示和交互控制,此外,还提供了常见问题解答和进一步学习建议,需要的朋友可以参考下
    2026-04-04

最新评论