pytorch torch.gather函数的使用

 更新时间:2024年09月09日 10:23:22   作者:qq_27390023  
torch.gather 是 PyTorch 中用于在指定维度上通过索引从源张量中提取元素的函数,它需要输入张量、维度索引和索引张量,示例代码展示了如何使用 torch.gather 从输入张量中按索引提取元素,返回的结果张量形状与索引张量相同

pytorch torch.gather函数

torch.gather 是 PyTorch 中的一个用于从给定维度上按索引取值的函数。

它根据一个索引张量 index,从源张量 input 中收集值,并返回一个新的张量。

torch.gather 常用于需要从张量的特定位置抽取元素的操作。

1. 函数签名

torch.gather(input, dim, index, *, sparse_grad=False, out=None)
  • input:输入张量,表示要从中收集元素的源张量。
  • dim:要收集的维度索引。例如,对于一个二维张量,0 表示沿着行的维度,1 表示沿着列的维度。
  • index:索引张量,其形状应与input张量在除了dim维度之外的其他维度上保持一致。索引张量中的值表示在input张量对应维度上要收集的元素的索引。
  • out(可选):输出张量,如果提供,结果将存储在这个张量中。

2. 工作原理

torch.gatherdim 维度上,通过 index 指定的索引,从 input 中选取元素。

返回的张量的形状与 index 的形状相同。

3. 示例代码

以下是一个简单的示例代码,演示如何使用 torch.gather 函数:

import torch

# 创建一个源张量
input = torch.tensor([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])

# 创建一个索引张量
index = torch.tensor([[0, 2, 1],
                      [2, 0, 1],
                      [1, 2, 0]])

# 在 dim=1 维度上使用 gather 函数
result = torch.gather(input, dim=1, index=index)

print("Input Tensor:")
print(input)
print("\nIndex Tensor:")
print(index)
print("\nResult Tensor:")
print(result)

4. 输出结果

Input Tensor:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

Index Tensor:
tensor([[0, 2, 1],
        [2, 0, 1],
        [1, 2, 0]])

Result Tensor:
tensor([[1, 3, 2],
        [6, 4, 5],
        [8, 9, 7]])

5. 解释

  • 输入张量 (input) 是一个 3x3 的矩阵,每个元素代表一个值。
  • 索引张量 (index) 指定了要从 input 中提取的元素的索引。
  • 结果张量 (result) 是根据 indexinput 中提取的元素形成的张量。

在这个例子中:

  • 对于 input 的第一行,index 提取了索引 0, 2, 1 对应的元素 1, 3, 2
  • 对于 input 的第二行,index 提取了索引 2, 0, 1 对应的元素 6, 4, 5
  • 对于 input 的第三行,index 提取了索引 1, 2, 0 对应的元素 8, 9, 7

总结

torch.gather 通过索引在指定维度上提取张量中的元素,是用于基于索引选择数据的有用工具。

函数对批处理数据特别有用,例如在分类任务中提取对应类别的概率或得分。

索引张量的形状必须与源张量在指定维度的形状相匹配,以确保正确的取值操作。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python中的3D绘图命令总结

    Python中的3D绘图命令总结

    很多情况下,为了能够观察到数据之间的内部的关系,可以使用绘图来更好的显示规律。而Python的matplotlib库中有很多三维图表显示的命令,本文为大家做了一个总结,需要的可以参考一下
    2022-02-02
  • pandas筛选某列出现编码错误的解决方法

    pandas筛选某列出现编码错误的解决方法

    今天小编就为大家分享一篇pandas筛选某列出现编码错误的解决方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • 使用Python Tkinter创建一个动态祝福弹窗的详细教程

    使用Python Tkinter创建一个动态祝福弹窗的详细教程

    本文手把手教你用Python的Tkinter库创建一个浪漫的弹窗程序,包含淡入淡出动画、多线程管理、队列控制等高级特性,通过完整的代码解析和配置指南,带你掌握GUI编程的核心技巧,需要的朋友可以参考下
    2025-11-11
  • Python中schedule模块关于定时任务使用方法

    Python中schedule模块关于定时任务使用方法

    这篇文章主要介绍了Python中schedule模块关于定时任务使用方法,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-05-05
  • 在Python中使用dict和set方法的教程

    在Python中使用dict和set方法的教程

    这篇文章主要介绍了在Python中使用dict和set方法的教程,dict字典是Python中的重要基础知识,set与其类似,需要的朋友可以参考下
    2015-04-04
  • Python 对象序列化与反序列化之pickle json详细解析

    Python 对象序列化与反序列化之pickle json详细解析

    我们知道在Python中,一切皆为对象,实例是对象,类是对象,元类也是对象。本文正是要聊聊如何将这些对象有效地保存起来,以供后续使用
    2021-09-09
  • 详解Python正则表达式re模块

    详解Python正则表达式re模块

    这篇文章主要介绍了Python正则表达式re模块,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-03-03
  • Python实现简单石头剪刀布小游戏的示例代码

    Python实现简单石头剪刀布小游戏的示例代码

    石头剪刀布是一种简单而又经典的游戏,常常用于决定胜负或者娱乐消遣,本文将使用Python实现一个简单的石头剪刀布游戏,需要的可以参考一下
    2023-06-06
  • 在Python中合并字典模块ChainMap的隐藏坑【推荐】

    在Python中合并字典模块ChainMap的隐藏坑【推荐】

    在Python中,当我们有两个字典需要合并的时候,可以使用字典的 update 方法,接下来通过本文给大家介绍在Python中合并字典模块ChainMap的隐藏坑,感兴趣的朋友一起看看吧
    2019-06-06
  • Python使用scrapy爬取阳光热线问政平台过程解析

    Python使用scrapy爬取阳光热线问政平台过程解析

    这篇文章主要介绍了Python使用scrapy爬取阳光热线问政平台过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08

最新评论