Pytorch中的masked_fill基本知识详解

 更新时间:2024年10月26日 10:06:41   作者:码农研究僧  
本文介绍了PyTorch中masked_fill函数的基本使用和原理,该函数接受一个输入张量和一个布尔掩码作为参数,掩码的形状必须与输入张量相同,True表示需要填充的位置,False表示保持原值

1. 基本知识

基本的原理知识如下:

输入张量和掩码
masked_fill 接受两个主要参数:一个输入张量和一个布尔掩码
掩码的形状必须与输入张量相同,True 表示需要填充的位置,False 表示保持原值

掩码操作
在执行 masked_fill 操作时,函数会检查掩码中每个元素的值
如果掩码对应的位置为 True,则在输出张量中填充指定的值;
如果为 False,则保留输入张量中对应位置的值

输出结果
最终生成的新张量包含了在掩码位置上被替换的值,其余位置保持原样

在代码逻辑上

创建掩码
mask 是一个布尔张量,标识了哪些位置需要填充:

[[False, True, False],
 [True, False, True],
 [False, False, True]]

执行 masked_fill
当调用 tensor.masked_fill(mask, -1) 时,PyTorch 会遍历掩码中的每个元素:对于 mask 中的每个 True 值,tensor 在对应位置的值会被替换为 -1,对于 False 值,保持原值不变

masked_fill 操作是基于 C/C++ 的实现,因此在处理大规模数据时性能较高。常用于深度学习模型中的数据预处理,比如在填充序列、处理缺失值或标记特定条件的数据时

2. Demo

Demo 1: 基本用法

import torch

# 创建一个 3x3 的张量
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# 创建一个掩码,标记要填充的位置
mask = torch.tensor([[False, True, False],
                     [True, False, True],
                     [False, False, True]])

# 使用 masked_fill 填充掩码位置为 -1
result = tensor.masked_fill(mask, -1)

print("原始张量:")
print(tensor)
print("\n填充后的张量:")
print(result)

截图如下:

Demo 2: 与条件结合使用

import torch
# 创建一个随机张量
tensor = torch.randn(3, 3)
# 创建掩码:标记负值的位置
mask = tensor < 0
# 将负值位置填充为 0
result = tensor.masked_fill(mask, 0)
print("原始张量:")
print(tensor)
print("\n填充后的张量 (负值填充为 0):")
print(result)

截图如下:

Demo 3: 结合计算

import torch
# 创建一个张量
tensor = torch.tensor([[10, 20, 30],
                       [40, 50, 60],
                       [70, 80, 90]])
# 创建掩码:标记大于 50 的位置
mask = tensor > 50
# 用 999 填充大于 50 的位置
result = tensor.masked_fill(mask, 999)
print("原始张量:")
print(tensor)
print("\n填充后的张量 (大于 50 的位置填充为 999):")
print(result)

截图如下:

到此这篇关于Pytorch中的masked_fill基本知识的文章就介绍到这了,更多相关Pytorch masked_fill内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Pycharm打印大数据文件显示不全的解决方法

    Pycharm打印大数据文件显示不全的解决方法

    这篇文章主要介绍了Pycharm打印大数据文件显示不全的解决方法,昨晚写了个小爬虫,简单分析下发现可以修改请求的url,直接获取所有目标的数据,想先打印在控制台看看,发现打印的数据不全,所以本文记录了一下解决方法,需要的朋友可以参考下
    2024-03-03
  • Python结合turtle简单开发一个烟花小工具

    Python结合turtle简单开发一个烟花小工具

    这篇文章主要为大家详细介绍了Python如何利用turtle模块实现的简单的烟花效果展示小工具,文章的示例代码讲解详细,感兴趣的小伙伴可以了解下
    2025-12-12
  • Pytest中Fixtures的高级用法

    Pytest中Fixtures的高级用法

    Fixtures 是 pytest 中一个非常强大的特性,它可以帮助我们提高测试的可维护性、可读性和可重复性,下面就来介绍一下,具有一定的参考价值,感兴趣的可以了解一下
    2025-05-05
  • Python 爬虫实现增加播客访问量的方法实现

    Python 爬虫实现增加播客访问量的方法实现

    这篇文章主要介绍了Python 爬虫实现增加播客访问量的方法实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-10-10
  • Python使用yaml模块操作YAML文档的方法

    Python使用yaml模块操作YAML文档的方法

    YAML是可读性高,用来表达数据序列化格式的,专用于写配置文件的语言,这篇文章主要介绍了Python使用yaml模块操作YAML文档,需要的朋友可以参考下
    2023-01-01
  • Python通过队列实现进程间通信详情

    Python通过队列实现进程间通信详情

    这篇文章主要介绍了Python通过队列实现进程间通信详情文章通过提出问题:在多进程中,每个进程之间是什么关系展开主题相关内容,感兴趣的朋友可以参考一下
    2022-06-06
  • Python解析网页源代码中的115网盘链接实例

    Python解析网页源代码中的115网盘链接实例

    这篇文章主要介绍了Python解析网页源代码中的115网盘链接实例,主要采用了正则表达式re模块来实现该功能,需要的朋友可以参考下
    2014-09-09
  • PyTorch 检查GPU版本是否安装成功的操作

    PyTorch 检查GPU版本是否安装成功的操作

    这篇文章主要介绍了PyTorch 检查GPU版本是否安装成功的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • Python命名空间的本质和加载顺序

    Python命名空间的本质和加载顺序

    这篇文章主要介绍了Python命名空间的本质和加载顺序,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2018-12-12
  • PyTorch张量操作指南(cat、stack、split与chunk)

    PyTorch张量操作指南(cat、stack、split与chunk)

    本文深入探讨PyTorch中用于调整张量结构的四个核心函数——torch.cat、torch.stack、torch.split和torch.chunk,通过实际应用场景分析和代码演示,帮助读者掌握它们的功能差异及适用条件,提升模型开发的灵活性与效率,需要的朋友可以参考下
    2025-04-04

最新评论