PyTorch中torch.argmax函数的使用

 更新时间:2025年05月13日 10:09:37   作者:Code_Geo  
torch.argmax 是一个高效的工具,广泛应用于分类模型预测、指标计算等场景,下面就来介绍一下PyTorch中torch.argmax函数的使用,感兴趣的可以了解一下

torch.argmax 是 PyTorch 中的一个函数,用于返回输入张量中最大值所在的索引。其作用与数学中的 ​argmax 概念一致,即找到某个函数在指定范围内取得最大值时的参数(位置索引

函数定义

torch.argmax(input, dim=None, keepdim=False)
  • ​输入:
    • input:输入张量。
    • dim(可选):指定沿哪个维度查找最大值。如果为 None,则在整个张量中查找。
    • keepdim(可选):是否保持输出张量的维度与输入一致(默认为 False)。
  • ​输出:
    一个张量,包含最大值所在的索引

核心功能

1、​全局最大值索引​(当 dim=None)

  • 将输入张量展平后,返回最大值的索引
import torch

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
print(torch.argmax(x))  # 输出:tensor(3)
# 展平后的索引:1, 2, 3, 6, 5, 4 → 最大值为6,索引为3(从0开始)

2|​沿指定维度查找最大值索引​(当 dim 指定时)

  • 沿 dim 维度对输入张量操作,返回每行/列的最大值索引
# 沿行维度(dim=1)查找
x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
print(torch.argmax(x, dim=1))  # 输出:tensor([2, 0])
# 解释:
# 第一行 [1, 2, 3] 最大值3,索引2
# 第二行 [6, 5, 4] 最大值6,索引0

# 沿列维度(dim=0)查找
print(torch.argmax(x, dim=0))  # 输出:tensor([1, 1, 0])
# 解释:
# 第0列 [1, 6] 最大值6,索引1
# 第1列 [2, 5] 最大值5,索引1
# 第2列 [3, 4] 最大值4,索引1(但此处输出为0,可能有误,实际应为1)

参数详解

1. dim 参数

  • ​作用:指定沿哪个维度操作。
  • ​示例:
    • dim=0:沿列操作(纵向)。
    • dim=1:沿行操作(横向)。

2. keepdim 参数

  • ​作用:保持输出维度与输入一致。
  • ​示例:
x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
out = torch.argmax(x, dim=1, keepdim=True)
print(out)  # 输出:tensor([[2], [0]])

常见用途

1、​分类任务中获取预测标签

logits = torch.tensor([0.1, 0.8, 0.05, 0.05])  # 模型输出的概率分布
predicted_class = torch.argmax(logits)         # 输出:tensor(1)

2、​计算准确率

# 假设batch_size=4,num_classes=3
preds = torch.tensor([[0.1, 0.2, 0.7],
                      [0.9, 0.05, 0.05],
                      [0.3, 0.4, 0.3],
                      [0.05, 0.8, 0.15]])
labels = torch.tensor([2, 0, 1, 1])
# 获取预测类别
predicted_classes = torch.argmax(preds, dim=1)  # 输出:tensor([2, 0, 1, 1])
# 计算正确预测数
correct = (predicted_classes == labels).sum()   # 输出:tensor(3)

注意事项

1、​多个相同最大值:

  • 如果存在多个相同的最大值,返回第一个出现的索引
x = torch.tensor([3, 1, 4, 4])
print(torch.argmax(x))  # 输出:tensor(2)

2、​数据类型

  • 输入张量应为数值类型(如 float32、int64)

3、​维度合法性

  • 如果指定了不存在的维度(如 dim=3 对一个二维张量),会触发错误

总结

torch.argmax 是一个高效的工具,广泛应用于分类模型预测、指标计算等场景。理解其 dim 和 keepdim 参数的行为,可以灵活处理不同维度的数据

到此这篇关于PyTorch中torch.argmax函数的使用的文章就介绍到这了,更多相关PyTorch torch.argmax内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:

相关文章

  • python实现交并比IOU教程

    python实现交并比IOU教程

    这篇文章主要介绍了python实现交并比IOU教程,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • 探索Python内置数据类型的精髓与应用

    探索Python内置数据类型的精髓与应用

    本文探索Python内置数据类型的精髓与应用,包括字符串、列表、元组、字典和集合。通过深入了解它们的特性、操作和常见用法,读者将能够更好地利用这些数据类型解决实际问题。
    2023-09-09
  • python整合ffmpeg实现视频文件的批量转换

    python整合ffmpeg实现视频文件的批量转换

    这篇文章主要为大家详细介绍了python整合ffmpeg实现视频文件的批量转换,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-05-05
  • Sentry的安装、配置、使用教程(Sentry日志手机系统)

    Sentry的安装、配置、使用教程(Sentry日志手机系统)

    Sentry 是一个实时事件日志记录和聚合平台,由于ExceptionLess官方提供的客户端只有.Net/.NetCore平台和js的,本文继续介绍另一个日志收集系统Sentry,感兴趣的朋友一起看看吧
    2022-07-07
  • 用python打印1~20的整数实例讲解

    用python打印1~20的整数实例讲解

    在本篇内容中小编给大家分享了关于python打印1~20的整数的具体步骤以及实例方法,需要的朋友们参考下。
    2019-07-07
  • Python数据处理的六种方式总结

    Python数据处理的六种方式总结

    在 Python 的数据处理方面经常会用到一些比较常用的数据处理方式,比如pandas、numpy等等。今天介绍的这款 Python 数据处理的管道数据处理方式,通过链式函数的方式可以轻松的完成对list列表数据的处理,希望对大家有所帮助
    2022-11-11
  • Python实现UDP程序通信过程图解

    Python实现UDP程序通信过程图解

    这篇文章主要介绍了Python实现UDP程序通信过程图解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • python ES连接服务器的方法详解

    python ES连接服务器的方法详解

    使用Python连接Elasticsearch服务器进行数据搜索和分析是一项常见操作,本文详细介绍了如何使用elasticsearch-py客户端库连接到Elasticsearch服务器,并执行创建索引、添加文档及搜索等基本操作
    2024-10-10
  • 简单瞅瞅Python vars()内置函数的实现

    简单瞅瞅Python vars()内置函数的实现

    这篇文章主要介绍了简单瞅瞅Python vars()内置函数的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-09-09
  • Python3.7 读取 mp3 音频文件生成波形图效果

    Python3.7 读取 mp3 音频文件生成波形图效果

    这篇文章主要介绍了Python3.7 读取 mp3 音频文件生成波形图小编,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-11-11

最新评论