Pytorch之上/下采样函数torch.nn.functional.interpolate插值详解

 更新时间:2025年04月16日 09:31:12   作者:Yuezero_  
这篇文章主要介绍了Pytorch之上/下采样函数torch.nn.functional.interpolate插值,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

Pytorch上/下采样函数torch.nn.functional.interpolate插值

torch.nn.functional.interpolate(input_tensor, size=None, scale_factor=8, mode='bilinear', align_corners=False)
'''
Down/up samples the input to either the given size or the given scale_factor
The algorithm used for interpolation is determined by mode.
Currently temporal, spatial and volumetric sampling are supported, i.e. expected inputs are 3-D, 4-D or 5-D in shape.
The input dimensions are interpreted in the form: mini-batch x channels x [optional depth] x [optional height] x width.
The modes available for resizing are: nearest, linear (3D-only), bilinear, bicubic (4D-only), trilinear (5D-only), area
'''

这个函数是用来上采样下采样tensor的空间维度(h,w)

input_tensor支持输入3D (b, c, w)或(batch,seq_len,dim)、4D (b, c, h, w)、5D (b, c, f, h, w)的 tensor shape。其中b表示batch_size,c表示channel,f表示frames,h表示height,w表示weight。

size是目标tensor的(w)/(h,w)/(f,h,w)的形状;scale_factor是采样tensor的saptial shape(w)/(h,w)/(f,h,w)的缩放系数,sizescale_factor两个参数只能定义一个,具体是上采样,还是下采样根据这两个参数判断。如果size或者scale_factorlist序列,则必须匹配输入的大小。

  • 如果输入3D,则它们的序列长度必须是1(只缩放最后1个维度w)。
  • 如果输入4D,则它们的序列长度必须是2(缩放最后2个维度h,w)。
  • 如果输入是5D,则它们的序列长度必须是3(缩放最后3个维度f,h,w)。

插值算法mode可选:最近邻(nearest, 默认)线性(linear, 3D-only)双线性(bilinear, 4D-only)三线性(trilinear, 5D-only)等等。

是否align_corners对齐角点:可选的bool值, 如果 align_corners=True,则对齐 input 和 output 的角点像素(corner pixels),保持在角点像素的值. 只会对 mode=linear, bilinear, trilinear 有作用. 默认是 False。一图看懂align_corners=TrueFalse的区别,从4×4上采样成8×8。

一个是按四角的像素点中心对齐,另一个是按四角的像素角点对齐:

import torch
import torch.nn.functional as F
b, c, f, h, w = 1, 3, 8, 64, 64

1. upsample/downsample 3D tensor

# interpolate 3D tensor
x = torch.randn([b, c, w])
## downsample to (b, c, w/2)
y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
y1 = F.interpolate(x, size=[w//2], mode='nearest')
y2 = F.interpolate(x, scale_factor=0.5, mode='linear')  # only 3D
y3 = F.interpolate(x, size=[w//2], mode='linear')  # only 3D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 32]) torch.Size([1, 3, 32]) torch.Size([1, 3, 32]) torch.Size([1, 3, 32])

## upsample to (b, c, w*2)
y0 = F.interpolate(x, scale_factor=2, mode='nearest')
y1 = F.interpolate(x, size=[w*2], mode='nearest')
y2 = F.interpolate(x, scale_factor=2, mode='linear')  # only 3D
y3 = F.interpolate(x, size=[w*2], mode='linear')  # only 3D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 128]) torch.Size([1, 3, 128]) torch.Size([1, 3, 128]) torch.Size([1, 3, 128])

2. upsample/downsample 4D tensor

# interpolate 4D tensor
x = torch.randn(b, c, h, w)
## downsample to (b, c, h/2, w/2)
y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
y1 = F.interpolate(x, size=[h//2, w//2], mode='nearest')
y2 = F.interpolate(x, scale_factor=0.5, mode='bilinear')  # only 4D
y3 = F.interpolate(x, size=[h//2, w//2], mode='bilinear')  # only 4D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32])

## upsample to (b, c, h*2, w*2)
y0 = F.interpolate(x, scale_factor=2, mode='nearest')
y1 = F.interpolate(x, size=[h*2, w*2], mode='nearest')
y2 = F.interpolate(x, scale_factor=2, mode='bilinear')  # only 4D
y3 = F.interpolate(x, size=[h*2, w*2], mode='bilinear')  # only 4D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128])

3. upsample/downsample 5D tensor

# interpolate 5D tensor
x = torch.randn(b, c, f, h, w)
## downsample to (b, c, f/2, h/2, w/2)
y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
y1 = F.interpolate(x, size=[f//2, h//2, w//2], mode='nearest')
y2 = F.interpolate(x, scale_factor=2, mode='trilinear')  # only 5D
y3 = F.interpolate(x, size=[f//2, h//2, w//2], mode='trilinear')  # only 5D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 4, 32, 32]) torch.Size([1, 3, 4, 32, 32]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 4, 32, 32])

## upsample to (b, c, f*2, h*2, w*2)
y0 = F.interpolate(x, scale_factor=2, mode='nearest')
y1 = F.interpolate(x, size=[f*2, h*2, w*2], mode='nearest')
y2 = F.interpolate(x, scale_factor=2, mode='trilinear')  # only 5D
y3 = F.interpolate(x, size=[f*2, h*2, w*2], mode='trilinear')  # only 5D
print(y0.shape, y1.shape, y2.shape, y3.shape)
# torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python 通过监听端口实现唯一脚本运行方式

    Python 通过监听端口实现唯一脚本运行方式

    这篇文章主要介绍了Python 通过监听端口实现唯一脚本运行方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • pytorch 固定部分参数训练的方法

    pytorch 固定部分参数训练的方法

    今天小编就为大家分享一篇pytorch 固定部分参数训练的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • 一步步教你用Python画五彩气球

    一步步教你用Python画五彩气球

    这篇文章主要给大家介绍了关于如何用Python画五彩气球的相关资料,主要是用turtle库自带的画笔turtle.Turtle()来绘制气球,文中给出了详细的实例代码,需要的朋友可以参考下
    2023-06-06
  • Python通过Pillow实现图片对比

    Python通过Pillow实现图片对比

    这篇文章主要介绍了Python Pillow实现图片对比,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-04-04
  • Python实现八大排序算法

    Python实现八大排序算法

    这篇文章主要介绍了Python实现八大排序算法,如何用Python实现八大排序算法,感兴趣的小伙伴们可以参考一下
    2016-08-08
  • pandas 如何保存数据到excel,csv

    pandas 如何保存数据到excel,csv

    这篇文章主要介绍了pandas 如何保存数据到excel,csv的实现方式,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-07-07
  • Gauss-Seidel迭代算法的Python实现详解

    Gauss-Seidel迭代算法的Python实现详解

    这篇文章主要介绍了Gauss-Seidel迭代算法的Python实现详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-06-06
  • 一文带你了解Python中Scikit-learn库的使用

    一文带你了解Python中Scikit-learn库的使用

    Scikit-learn是Python的一个开源机器学习库,它支持监督和无监督学习,本文主要来深入探讨一下Scikit-learn的更高级的特性,感兴趣的小伙伴可以了解下
    2023-07-07
  • python初学之用户登录的实现过程(实例讲解)

    python初学之用户登录的实现过程(实例讲解)

    下面小编就为大家分享一篇python初学之用户登录的实现过程(实例讲解),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2017-12-12
  • Python实现的简单读写csv文件操作示例

    Python实现的简单读写csv文件操作示例

    这篇文章主要介绍了Python实现的简单读写csv文件操作,结合实例形式分析了Python使用csv模块针对csv文件进行读写操作的相关实现技巧与注意事项,需要的朋友可以参考下
    2018-07-07

最新评论