PyTorch中torch.argmax函数的使用

 更新时间:2025年05月13日 10:09:37   作者:Code_Geo  
torch.argmax 是一个高效的工具,广泛应用于分类模型预测、指标计算等场景,下面就来介绍一下PyTorch中torch.argmax函数的使用,感兴趣的可以了解一下

torch.argmax 是 PyTorch 中的一个函数,用于返回输入张量中最大值所在的索引。其作用与数学中的 ​argmax 概念一致,即找到某个函数在指定范围内取得最大值时的参数(位置索引

函数定义

torch.argmax(input, dim=None, keepdim=False)
  • ​输入:
    • input:输入张量。
    • dim(可选):指定沿哪个维度查找最大值。如果为 None,则在整个张量中查找。
    • keepdim(可选):是否保持输出张量的维度与输入一致(默认为 False)。
  • ​输出:
    一个张量,包含最大值所在的索引

核心功能

1、​全局最大值索引​(当 dim=None)

  • 将输入张量展平后,返回最大值的索引
import torch

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
print(torch.argmax(x))  # 输出:tensor(3)
# 展平后的索引:1, 2, 3, 6, 5, 4 → 最大值为6,索引为3(从0开始)

2|​沿指定维度查找最大值索引​(当 dim 指定时)

  • 沿 dim 维度对输入张量操作,返回每行/列的最大值索引
# 沿行维度(dim=1)查找
x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
print(torch.argmax(x, dim=1))  # 输出:tensor([2, 0])
# 解释:
# 第一行 [1, 2, 3] 最大值3,索引2
# 第二行 [6, 5, 4] 最大值6,索引0

# 沿列维度(dim=0)查找
print(torch.argmax(x, dim=0))  # 输出:tensor([1, 1, 0])
# 解释:
# 第0列 [1, 6] 最大值6,索引1
# 第1列 [2, 5] 最大值5,索引1
# 第2列 [3, 4] 最大值4,索引1(但此处输出为0,可能有误,实际应为1)

参数详解

1. dim 参数

  • ​作用:指定沿哪个维度操作。
  • ​示例:
    • dim=0:沿列操作(纵向)。
    • dim=1:沿行操作(横向)。

2. keepdim 参数

  • ​作用:保持输出维度与输入一致。
  • ​示例:
x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
out = torch.argmax(x, dim=1, keepdim=True)
print(out)  # 输出:tensor([[2], [0]])

常见用途

1、​分类任务中获取预测标签

logits = torch.tensor([0.1, 0.8, 0.05, 0.05])  # 模型输出的概率分布
predicted_class = torch.argmax(logits)         # 输出:tensor(1)

2、​计算准确率

# 假设batch_size=4,num_classes=3
preds = torch.tensor([[0.1, 0.2, 0.7],
                      [0.9, 0.05, 0.05],
                      [0.3, 0.4, 0.3],
                      [0.05, 0.8, 0.15]])
labels = torch.tensor([2, 0, 1, 1])
# 获取预测类别
predicted_classes = torch.argmax(preds, dim=1)  # 输出:tensor([2, 0, 1, 1])
# 计算正确预测数
correct = (predicted_classes == labels).sum()   # 输出:tensor(3)

注意事项

1、​多个相同最大值:

  • 如果存在多个相同的最大值,返回第一个出现的索引
x = torch.tensor([3, 1, 4, 4])
print(torch.argmax(x))  # 输出:tensor(2)

2、​数据类型

  • 输入张量应为数值类型(如 float32、int64)

3、​维度合法性

  • 如果指定了不存在的维度(如 dim=3 对一个二维张量),会触发错误

总结

torch.argmax 是一个高效的工具,广泛应用于分类模型预测、指标计算等场景。理解其 dim 和 keepdim 参数的行为,可以灵活处理不同维度的数据

到此这篇关于PyTorch中torch.argmax函数的使用的文章就介绍到这了,更多相关PyTorch torch.argmax内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:

相关文章

  • python中pygame模块用法实例

    python中pygame模块用法实例

    这篇文章主要介绍了python中pygame模块用法实例,通过图形绘制来简单讲述了pygame模块的用法,具有很好的参考借鉴价值,需要的朋友可以参考下
    2014-10-10
  • 实例讲解python函数式编程

    实例讲解python函数式编程

    这篇文章主要介绍了python函数式编程实例,使用一个例子来阐述python函数式编程,需要的朋友可以参考下
    2014-06-06
  • Python读取和处理分析tif数据的超详细教程

    Python读取和处理分析tif数据的超详细教程

    TIF格式是一种跨平台的图片格式,可同时支持Windows和Mac系统的操作,TIF格式可以在保证图片不失真的情况下压缩,且保留图片的分层或是透明信息,这篇文章主要介绍了Python读取和处理分析tif数据的相关资料,需要的朋友可以参考下
    2025-11-11
  • Python使用剪切板的方法

    Python使用剪切板的方法

    这篇文章主要为大家详细介绍了Python使用剪切板的方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-06-06
  • Numpy中np.max的用法及np.maximum区别

    Numpy中np.max的用法及np.maximum区别

    这篇文章主要介绍了Numpy中np.max的用法及np.maximum区别,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11
  • Python+Pygame实现之走四棋儿游戏的实现

    Python+Pygame实现之走四棋儿游戏的实现

    大家以前应该都听说过一个游戏:叫做走四棋儿。直接在家里的水泥地上用烧完的炭火灰画出几条线,摆上几颗石头子即可。当时的火爆程度可谓是达到了一个新的高度。本文将利用Pygame实现这一游戏,需要的可以参考一下
    2022-07-07
  • 8个让Python代码效率翻倍的简单技巧

    8个让Python代码效率翻倍的简单技巧

    这篇文章主要为大家详细介绍了8个让Python代码效率翻倍的简单技巧,文中的示例代码讲解详细,具有一定的借鉴价值,感兴趣的小伙伴可以参考一下
    2026-01-01
  • 在pycharm中配置Anaconda以及pip源配置详解

    在pycharm中配置Anaconda以及pip源配置详解

    这篇文章主要介绍了在pycharm中配置Anaconda以及pip源配置详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-09-09
  • 使用Python的Twisted框架编写简单的网络客户端

    使用Python的Twisted框架编写简单的网络客户端

    这篇文章主要介绍了使用Python的Twisted框架编写简单的网络客户端,翻译自Twisted文档,包括一个简单的IRC客户端的实现,需要的朋友可以参考下
    2015-04-04
  • PyTorch一小时掌握之神经网络分类篇

    PyTorch一小时掌握之神经网络分类篇

    这篇文章主要介绍了PyTorch一小时掌握之神经网络分类篇,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-09-09

最新评论