PyTorch torch.unique() 基础与实战应用指南

 更新时间:2025年10月29日 11:23:04   作者:Geoking.  
torch.unique() 是PyTorch中的一个去重函数,用于返回张量中所有的唯一元素(unique elements),本文将带你深入了解 torch.unique() 的用法、参数、返回值以及实际应用场景,感兴趣的朋友跟随小编一起看看吧

在深度学习的数据处理中经常需要统计或筛选 张量(Tensor) 中的唯一值,比如去重、统计类别数量、计算唯一标签数等。
PyTorch 提供了一个非常方便的函数 —— torch.unique(),可以轻松完成这些操作。

本文将带你深入了解 torch.unique() 的用法、参数、返回值以及实际应用场景。

一、什么是torch.unique()?

torch.unique() 是 PyTorch 中的一个去重函数,用于返回张量中所有的唯一元素(unique elements)。

它类似于 Python 的 set() 或 NumPy 的 np.unique(),但专为 GPU 加速的张量操作 设计。

二、函数语法

torch.unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None)

三、参数说明

参数类型说明
inputTensor输入张量
sortedbool是否对结果排序(默认 True
return_inversebool是否返回原张量中每个值在唯一值列表中的索引
return_countsbool是否返回每个唯一值的出现次数
dimintNone按指定维度去重,默认对整个张量去重

四、基本用法

🎯 示例 1:基础去重

import torch
x = torch.tensor([1, 2, 2, 3, 3, 3])
unique_x = torch.unique(x)
print(unique_x)

输出:

tensor([1, 2, 3])

✅ 结果去除了重复值,并自动排序。

🎯 示例 2:不排序

x = torch.tensor([3, 2, 1, 3, 2])
unique_x = torch.unique(x, sorted=False)
print(unique_x)

输出:

tensor([3, 2, 1])

sorted=False 时,结果的顺序与首次出现的顺序一致。

五、返回索引与计数

🎯 示例 3:return_inverse

return_inverse=True 会返回一个索引张量,表示原张量中每个元素在唯一值(即新张量)中的位置。

x = torch.tensor([2, 1, 2, 3])
u, inv = torch.unique(x, return_inverse=True)
print(u)
print(inv)

输出:

tensor([1, 2, 3])
tensor([1, 0, 1, 2])

解释:

  • 唯一值为 [1, 2, 3]
  • 原数组 [2, 1, 2, 3] 中:
    • 第一个元素 2 → 索引 1
    • 第二个元素 1 → 索引 0
    • 第三个元素 2 → 索引 1
    • 第四个元素 3 → 索引 2

🎯 示例 4:return_counts

return_counts=True 会返回每个唯一值出现的次数。

x = torch.tensor([1, 2, 2, 3, 3, 3])
u, counts = torch.unique(x, return_counts=True)
print(u)
print(counts)

输出:

tensor([1, 2, 3])
tensor([1, 2, 3])

表示:

  • 值 1 出现 1 次
  • 值 2 出现 2 次
  • 值 3 出现 3 次

🎯 示例 5:同时返回多个结果

你可以同时返回 unique 值、inverse 索引和计数

x = torch.tensor([1, 2, 2, 3, 3, 3])
u, inv, counts = torch.unique(x, return_inverse=True, return_counts=True)
print(u)
print(inv)
print(counts)

输出:

tensor([1, 2, 3])
tensor([0, 1, 1, 2, 2, 2])
tensor([1, 2, 3])

六、按维度去重(dim 参数)

默认情况下,torch.unique() 会将张量展开成一维后去重。
但如果你希望在特定维度上去重(如按行或按列),可以使用 dim 参数。

🎯 示例 6:按行去重

x = torch.tensor([[1, 2],
                  [1, 2],
                  [3, 4]])
unique_rows = torch.unique(x, dim=0)
print(unique_rows)

输出:

tensor([[1, 2],
        [3, 4]])

表示第 1、2 行重复,只保留一个。

🎯 示例 7:按列去重

x = torch.tensor([[1, 1, 3],
                  [2, 2, 4]])
unique_cols = torch.unique(x, dim=1)
print(unique_cols)

输出:

tensor([[1, 3],
        [2, 4]])

七、torch.unique()与 NumPy 对比

功能PyTorch (torch.unique)NumPy (np.unique)
默认排序✅ 是✅ 是
支持 GPU✅ 是❌ 否
返回 inverse 索引✅ 是✅ 是
返回 counts✅ 是✅ 是
按维度去重✅ 是(dim❌ 不直接支持
性能高(GPU 支持)仅 CPU

八、实际应用场景

1. 分类问题中统计类别数量

labels = torch.tensor([0, 1, 0, 2, 2, 1, 3])
classes = torch.unique(labels)
print(f"共有 {len(classes)} 个类别: {classes.tolist()}")

输出:

共有 4 个类别: [0, 1, 2, 3]

2. 计算样本分布(类别频率)

labels = torch.tensor([0, 1, 0, 2, 2, 1, 3])
u, counts = torch.unique(labels, return_counts=True)
for c, cnt in zip(u.tolist(), counts.tolist()):
    print(f"类别 {c}: {cnt} 个样本")

输出:

类别 0: 2 个样本
类别 1: 2 个样本
类别 2: 2 个样本
类别 3: 1 个样本

3. 在图像分割中统计像素类别

例如在语义分割任务中,计算 mask 图像中有多少个不同的像素类别:

mask = torch.randint(0, 5, (256, 256))  # 随机生成类别标签
num_classes = len(torch.unique(mask))
print(f"图像中共有 {num_classes} 个类别")

⚠️ 九、注意事项

  1. torch.unique()** 默认会对结果排序**,如果在意性能,可以设置 sorted=False
  2. 对高维张量使用 dim 去重时,必须保证该维度的所有元素形状一致。
  3. 对大张量使用 return_countsreturn_inverse 时可能会消耗更多显存。

📚 参考资料

NumPy 官方文档 – numpy.unique

到此这篇关于PyTorch torch.unique() 基础与实战应用指南的文章就介绍到这了,更多相关PyTorch torch.unique() 使用内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python深度学习tensorflow卷积层示例教程

    python深度学习tensorflow卷积层示例教程

    这篇文章主要为大家介绍了python深度学习tensorflow卷积层示例教程,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-06-06
  • Python return语句如何实现结果返回调用

    Python return语句如何实现结果返回调用

    这篇文章主要介绍了Python return语句如何实现结果返回调用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-10-10
  • Python实现自动化处理Word文档的方法详解

    Python实现自动化处理Word文档的方法详解

    本文主要介绍了如何使用Python实现Word文档的自动化处理,包括批量生成Word文档、在Word文档中批量进行查找和替换、将Word文档批量转换成PDF等,希望对你有所帮助
    2022-08-08
  • 利用Python的装饰器解决Bottle框架中用户验证问题

    利用Python的装饰器解决Bottle框架中用户验证问题

    这篇文章主要介绍了Python的Bottle框架中解决用户验证问题,代码基于Python2.x版本,需要的朋友可以参考下
    2015-04-04
  • 关于Series的index的方法和属性使用说明

    关于Series的index的方法和属性使用说明

    这篇文章主要介绍了关于Series的index的方法和属性使用说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-06-06
  • Pyinstaller打包工具的使用以及避坑

    Pyinstaller打包工具的使用以及避坑

    本文主要的是pyinstaller在windows下的基本使用和基础避坑,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-11-11
  • wxPython实现文本框基础组件

    wxPython实现文本框基础组件

    这篇文章主要介绍了wxPython实现文本框基础组件,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-11-11
  • python中有关时间日期格式转换问题

    python中有关时间日期格式转换问题

    这篇文章主要介绍了python中有关时间日期格式转换问题,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-12-12
  • Python实现批量更换指定目录下文件扩展名的方法

    Python实现批量更换指定目录下文件扩展名的方法

    这篇文章主要介绍了Python实现批量更换指定目录下文件扩展名的方法,结合完整实例分析了Python批量修改文件扩展名的技巧,并对比分析了shell命令及scandir的兼容性代码,需要的朋友可以参考下
    2016-09-09
  • python组合无重复三位数的实例

    python组合无重复三位数的实例

    今天小编就为大家分享一篇python组合无重复三位数的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11

最新评论