Pytorch中的masked_fill基本知识操作

 更新时间:2024年10月28日 10:02:55   作者:码农研究僧  
本文主要介绍了PyTorch中的masked_fill函数的基本知识和使用方法,masked_fill函数接受一个输入张量和一个布尔掩码作为主要参数,掩码的形状必须与输入张量相同,掩码操作根据掩码中的布尔值在输出张量中填充指定的值或保留输入张量中的值

1. 基本知识

基本的原理知识如下:

输入张量和掩码
masked_fill 接受两个主要参数:一个输入张量和一个布尔掩码
掩码的形状必须与输入张量相同,True 表示需要填充的位置,False 表示保持原值

掩码操作
在执行 masked_fill 操作时,函数会检查掩码中每个元素的值
如果掩码对应的位置为 True,则在输出张量中填充指定的值;
如果为 False,则保留输入张量中对应位置的值

输出结果
最终生成的新张量包含了在掩码位置上被替换的值,其余位置保持原样

在代码逻辑上

创建掩码
mask 是一个布尔张量,标识了哪些位置需要填充:

[[False, True, False],
 [True, False, True],
 [False, False, True]]

执行 masked_fill
当调用 tensor.masked_fill(mask, -1) 时,PyTorch 会遍历掩码中的每个元素:对于 mask 中的每个 True 值,tensor 在对应位置的值会被替换为 -1,对于 False 值,保持原值不变

masked_fill 操作是基于 C/C++ 的实现,因此在处理大规模数据时性能较高。常用于深度学习模型中的数据预处理,比如在填充序列、处理缺失值或标记特定条件的数据时

2. Demo

Demo 1: 基本用法

import torch
# 创建一个 3x3 的张量
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
# 创建一个掩码,标记要填充的位置
mask = torch.tensor([[False, True, False],
                     [True, False, True],
                     [False, False, True]])
# 使用 masked_fill 填充掩码位置为 -1
result = tensor.masked_fill(mask, -1)
print("原始张量:")
print(tensor)
print("\n填充后的张量:")
print(result)

截图如下:

Demo 2: 与条件结合使用

import torch
# 创建一个随机张量
tensor = torch.randn(3, 3)
# 创建掩码:标记负值的位置
mask = tensor < 0
# 将负值位置填充为 0
result = tensor.masked_fill(mask, 0)
print("原始张量:")
print(tensor)
print("\n填充后的张量 (负值填充为 0):")
print(result)

截图如下:

Demo 3: 结合计算

import torch
# 创建一个张量
tensor = torch.tensor([[10, 20, 30],
                       [40, 50, 60],
                       [70, 80, 90]])
# 创建掩码:标记大于 50 的位置
mask = tensor > 50
# 用 999 填充大于 50 的位置
result = tensor.masked_fill(mask, 999)
print("原始张量:")
print(tensor)
print("\n填充后的张量 (大于 50 的位置填充为 999):")
print(result)

截图如下:

到此这篇关于Pytorch中的masked_fill基本知识的文章就介绍到这了,更多相关Pytorch masked_fill内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python基础之内置模块详解

    Python基础之内置模块详解

    Python内置的模块有很多,我们也已经接触了不少相关模块,接下来咱们就来做一些项目开发中常用的模块汇总和介绍,需要的朋友可以参考下
    2021-06-06
  • python3 adb 获取设备序列号的实现

    python3 adb 获取设备序列号的实现

    这篇文章主要介绍了python3 adb 获取设备序列号的实现操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-06-06
  • python获取指定路径下所有指定后缀文件的方法

    python获取指定路径下所有指定后缀文件的方法

    这篇文章主要介绍了python获取指定路径下所有指定后缀文件的方法,涉及Python针对文件与目录操作的相关技巧,需要的朋友可以参考下
    2015-05-05
  • python3.5+tesseract+adb实现西瓜视频或头脑王者辅助答题

    python3.5+tesseract+adb实现西瓜视频或头脑王者辅助答题

    这篇文章主要介绍了python3.5+tesseract+adb实现西瓜视频或头脑王者辅助答题,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-01-01
  • Python实现的调用C语言函数功能简单实例

    Python实现的调用C语言函数功能简单实例

    这篇文章主要介绍了Python实现的调用C语言函数功能,结合简单实例形式分析了Python使用ctypes模块调用C语言函数的具体步骤与相关操作技巧,需要的朋友可以参考下
    2019-03-03
  • 使用Python实现跳一跳自动跳跃功能

    使用Python实现跳一跳自动跳跃功能

    这篇文章主要介绍了使用Python实现跳一跳自动跳跃功能,本文图文并茂通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-07-07
  • python简单实例训练(21~30)

    python简单实例训练(21~30)

    上篇文章给大家介绍了python简单实例训练的1-10,这里继续为大家介绍python的一些用法,希望大家每个例子都打出来测试一下
    2017-11-11
  • PyTorch中的参数类torch.nn.Parameter()详解

    PyTorch中的参数类torch.nn.Parameter()详解

    这篇文章主要给大家介绍了关于PyTorch中torch.nn.Parameter()的相关资料,要内容包括基础应用、实用技巧、原理机制等方面,文章通过实例介绍的非常详细,需要的朋友可以参考下
    2022-02-02
  • Python定义一个函数的方法

    Python定义一个函数的方法

    这篇文章主要介绍了Python定义一个函数的方法及相关实例,需要的朋友们可以学习参考下。
    2020-06-06
  • 浅谈django2.0 ForeignKey参数的变化

    浅谈django2.0 ForeignKey参数的变化

    今天小编就为大家分享一篇浅谈django2.0 ForeignKey参数的变化,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08

最新评论