解析Pytorch中的torch.gather()函数

 更新时间:2021年11月13日 13:58:40   作者:xiaoliujun1999  
本文给大家介绍了Pytorch中的torch.gather()函数,通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧

参数说明

以官方说明为例,gather()函数需要三个参数,输入input,维度dim,以及索引index

input必须为Tensor类型

dim为int类型,代表从哪个维度进行索引

index为LongTensor类型

举例说明

input=torch.tensor([[1,2,3],[4,5,6]]) #作为输入
 
index1=torch.tensor([[0,1,1],[0,1,1]]) #作为索引矩阵
 
# dim=0时,按列进行索引
print (torch.gather(input,dim=0,index=index1))
 
# dim=1时,按行进行索引
print (torch.gather(input,dim=1,index=index1))

 结果如下图所示:

# 按列进行索引
tensor([[1, 5, 6],
        [4, 2, 6]])
 
# 按行进行索引
tensor([[1, 2, 2],
        [5, 4, 5]])

画图说明 

官方文档

def gather(self, input, dim, index, *args, **kwargs): 
        
        For a 3-D tensor the output is specified by::
        
            out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
            out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
            out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2        
 
        Args:
            input (Tensor): the source tensor
            dim (int): the axis along which to index
            index (LongTensor): the indices of elements to gather     
      
        Example::
        
            >>> t = torch.tensor([[1, 2], [3, 4]])
            >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
            tensor([[ 1,  1],
                    [ 4,  3]])

到此这篇关于Pytorch中的torch.gather()函数的文章就介绍到这了,更多相关Pytorch torch.gather()函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 详解如何用TensorFlow训练和识别/分类自定义图片

    详解如何用TensorFlow训练和识别/分类自定义图片

    这篇文章主要介绍了详解如何用TensorFlow训练和识别/分类自定义图片,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-08-08
  • Python绘制柱状图可视化神器pyecharts

    Python绘制柱状图可视化神器pyecharts

    这篇文章主要介绍了Python绘制柱状图可视化神器pyecharts,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-09-09
  • Python机器学习NLP自然语言处理基本操作词袋模型

    Python机器学习NLP自然语言处理基本操作词袋模型

    本文是Python机器学习NLP自然语言处理系列文章,带大家开启一段学习自然语言处理 (NLP) 的旅程。本篇文章主要学习NLP自然语言处理基本操作之词袋模型
    2021-09-09
  • Django框架 信号调度原理解析

    Django框架 信号调度原理解析

    这篇文章主要介绍了Django框架 信号调度原理解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • Python中的asyncio代码详解

    Python中的asyncio代码详解

    asyncio 是用来编写 并发 代码的库,使用 async/await 语法。 asyncio 被用作多个提供高性能 Python 异步框架的基础,包括网络和网站服务,数据库连接库,分布式任务队列等等。这篇文章主要介绍了Python中的asyncio,需要的朋友可以参考下
    2019-06-06
  • Python数据结构与算法中的队列详解(2)

    Python数据结构与算法中的队列详解(2)

    这篇文章主要为大家详细介绍了Python中的队列,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2022-03-03
  • 详解OpenCV中简单的鼠标事件处理

    详解OpenCV中简单的鼠标事件处理

    谈及鼠标事件,就是在触发鼠标按钮后程序所做出相应的反应,但是不影响程序的整个线程。本文将主要介绍OpenCV中的简单鼠标事件处理,感兴趣的可以学习一下
    2022-01-01
  • python如何求100以内的素数

    python如何求100以内的素数

    在本篇文章里小编给大家分享的是关于python如何求100以内的素数的方法实例,需要的朋友们可以学习下。
    2020-05-05
  • YOLOv5改进之添加CBAM注意力机制的方法

    YOLOv5改进之添加CBAM注意力机制的方法

    注意力机制最先被用在NLP领域,Attention就是为了让模型认识到数据中哪一部分是最重要的,为它分配更大的权重,获得更多的注意力在一些特征上,让模型表现更好,这篇文章主要给大家介绍了关于YOLOv5改进之添加CBAM注意力机制的相关资料,需要的朋友可以参考下
    2022-11-11
  • Python OpenCV超详细讲解基本功能

    Python OpenCV超详细讲解基本功能

    OpenCV用C++语言编写,它具有C ++,Python,Java和MATLAB接口,并支持Windows,Linux,Android和Mac OS,OpenCV主要倾向于实时视觉应用,并在可用时利用MMX和SSE指令,本篇文章带你了解OpenCV的基本功能
    2022-04-04

最新评论