Pytorch中的torch.where函数使用

 更新时间:2024年02月26日 09:47:38   作者:烟雨风渡  
这篇文章主要介绍了Pytorch中的torch.where函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

使用torch.where函数

首先我们看一下Pytorch中torch.where函数是怎样定义的:

@overload
def where(condition: Tensor) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...

torch.where函数的功能如下:

torch.where(condition, x, y)

  • condition:判断条件
  • x:若满足条件,则取x中元素
  • y:若不满足条件,则取y中元素

以具体实例看一下torch.where函数的效果:

import torch
 
# 条件
condition = torch.rand(3, 2)
print(condition)
# 满足条件则取x中对应元素
x = torch.ones(3, 2)
print(x)
# 不满足条件则取y中对应元素
y = torch.zeros(3, 2)
print(y)
# 条件判断后的结果
result = torch.where(condition > 0.5, x, y)
print(result)

结果如下:

tensor([[0.3224, 0.5789],
        [0.8341, 0.1673],
        [0.1668, 0.4933]])
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])
tensor([[0., 1.],
        [1., 0.],
        [0., 0.]])

可以看到torch.where函数会对condition中的元素逐一进行判断,根据判断的结果选取x或y中的值,所以要求x和y应该与condition形状相同。

torch.where(),np.where()两种用法,及np.argwhere()寻找张量(tensor)和数组中为0的索引

1.torch.where()

torch.where()有两种用法,

  • 当输入参数为三个时,即torch.where(condition, x, y),返回满足 x if condition else y的tensor,注意x,y必须为tensor
  • 当输入参数为一个时,即torch.where(condition),返回满足condition的tensor索引的元组(tuple)

代码示例

torch.where(condition, x, y)

代码

import torch
import numpy as np
 
# 初始化两个tensor
x = torch.tensor([
    [1,2,3,0,6],
    [4,6,2,1,0],
    [4,3,0,1,1]
])
y = torch.tensor([
    [0,5,1,4,2],
    [5,7,1,2,9],
    [1,3,5,6,6]
])
 
# 寻找满足x中大于3的元素,否则得到y对应位置的元素
arr0 = torch.where(x>=3, x, y) #输入参数为3个
 
print(x, '\n', y)
print(arr0, '\n', type(arr0))

结果

>>> x
tensor([[1, 2, 3, 0, 6],
        [4, 6, 2, 1, 0],
        [4, 3, 0, 1, 1]])
>>> y
tensor([[0, 5, 1, 4, 2],
        [5, 7, 1, 2, 9],
        [1, 3, 5, 6, 6]])
 
>>> arr0
tensor([[0, 5, 3, 4, 6],
        [4, 6, 1, 2, 9],
        [4, 3, 5, 6, 6]])
 
>>> type(arr0)
<class 'torch.Tensor'>

arr0的类型为<class 'torch.Tensor'>

torch.where(condition)

以寻找tensor中为0的索引为例

代码

import torch
import numpy as np
x = torch.tensor([
    [1,2,3,0,6],
    [4,6,2,1,0],
    [4,3,0,1,1]
])
y = torch.tensor([
    [0,5,1,4,2],
    [5,7,1,2,9],
    [1,3,5,6,6]
])
 
# 返回x中0元素的索引
index0 = torch.where(x==0) # 输入参数为1个
 
print(index0,'\n', type(index0))

结果

>>> index0
(tensor([0, 1, 2]), tensor([3, 4, 2])) 
 
>>> type(index0)
<class 'tuple'>

其中[0, 1, 2]是0元素坐标的行索引,[3, 4, 2]是0元素坐标的列索引,注意,最终得到的是tuple类型的返回值,元组中包含了tensor

2.np.where()

np.where()用法与torch.where()用法类似,也包括两种用法,但是不同的是输入值类型和返回值的类型

代码示例

np.where(condition, x, y)和np.where(condition),输入x,y可以为非tensor

代码

import torch
import numpy as np
x = torch.tensor([
    [1,2,3,0,6],
    [4,6,2,1,0],
    [4,3,0,1,1]
])
y = torch.tensor([
    [0,5,1,4,2],
    [5,7,1,2,9],
    [1,3,5,6,6]
])
 
arr1 = np.where(x>=3, x, y) # 输入参数为3个
 
index0 = torch.where(x==0) # 输入参数为1个
 
print(arr1,'\n',type(arr1))
print(index1,'\n', type(index1))
 

结果

>>> arr1
[[0 5 3 4 6]
 [4 6 1 2 9]
 [4 3 5 6 6]]
 
>>> type(arr1)
<class 'numpy.ndarray'>
 
>>> index1
(array([0, 1, 2]), array([3, 4, 2])) 
 
>>> type(index1)
<class 'tuple'>

注意,np.where()和torch.where()的返回值类型不同

3.np.argwhere(condition)

寻找符合contion的元素索引

代码示例

代码

import torch
import numpy as np
x = torch.tensor([
    [1,2,3,0,6],
    [4,6,2,1,0],
    [4,3,0,1,1]
])
y = torch.tensor([
    [0,5,1,4,2],
    [5,7,1,2,9],
    [1,3,5,6,6]
])
 
 
index2 = np.argwhere(x==0) # 寻找元素为0的索引
 
print(index2,'\n', type(index2))

结果

>>> index2
tensor([[0, 1, 2],
        [3, 4, 2]]) 
 
>>> type(index2)
<class 'torch.Tensor'>

注意返回值的类型

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python通过pil模块将raw图片转换成png图片的方法

    python通过pil模块将raw图片转换成png图片的方法

    这篇文章主要介绍了python通过pil模块将raw图片转换成png图片的方法,实例分析了Python中pil模块的使用技巧,并Image.fromstring函数进行了较为详尽的分析说明,需要的朋友可以参考下
    2015-03-03
  • Python中命名元组Namedtuple的使用详解

    Python中命名元组Namedtuple的使用详解

    Python支持一种名为“namedtuple()”的容器字典,它存在于模块“collections”中,下面就跟随小编一起学习一下Namedtuple的具体使用吧
    2023-09-09
  • 30道python自动化测试面试题与答案汇总

    30道python自动化测试面试题与答案汇总

    对于机器学习算法工程师而言,Python是不可或缺的语言,它的优美与简洁令人无法自拔,下面这篇文章主要给大家介绍了关于30道python自动化测试面试题与答案汇总的相关资料,需要的朋友可以参考下
    2023-03-03
  • Python使用pydub模块转换音频格式以及对音频进行剪辑

    Python使用pydub模块转换音频格式以及对音频进行剪辑

    这篇文章主要给大家介绍了关于Python使用pydub模块转换音频格式以及对音频进行剪辑的相关资料pydub是python的高级一个音频处理库,可以让你以一种不那么蠢的方法处理音频。需要的朋友可以参考下
    2021-06-06
  • Python设计模式之备忘录模式原理与用法详解

    Python设计模式之备忘录模式原理与用法详解

    这篇文章主要介绍了Python设计模式之备忘录模式原理与用法,结合实例形式详细分析了备忘录模式的相关概念、原理及Python相关实现技巧,需要的朋友可以参考下
    2019-01-01
  • Python实现的生产者、消费者问题完整实例

    Python实现的生产者、消费者问题完整实例

    这篇文章主要介绍了Python实现的生产者、消费者问题,简单描述了生产者、消费者问题的概念、原理,并结合完整实例形式分析了Python实现生产者、消费者问题的相关操作技巧,需要的朋友可以参考下
    2018-05-05
  • python 日志模块logging的使用场景及示例

    python 日志模块logging的使用场景及示例

    这篇文章主要介绍了python 日志模块logging的使用场景及示例,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2021-01-01
  • 用Python徒手撸一个股票回测框架搭建【推荐】

    用Python徒手撸一个股票回测框架搭建【推荐】

    回测框架就是提供这样的一个平台让交易策略在历史数据中不断交易,最终生成最终结果,通过查看结果的策略收益,年化收益,最大回测等用以评估交易策略的可行性。这篇文章主要介绍了用Python徒手撸一个股票回测框架,需要的朋友可以参考下
    2019-08-08
  • Python之时间和日期使用小结

    Python之时间和日期使用小结

    这篇文章主要介绍了Python之时间和日期使用小结,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-02-02
  • 深度学习TextLSTM的tensorflow1.14实现示例

    深度学习TextLSTM的tensorflow1.14实现示例

    这篇文章主要为大家介绍了深度学习TextLSTM的tensorflow1.14实现示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-01-01

最新评论