深度学习入门之Pytorch 数据增强的实现

 更新时间:2020年02月26日 14:40:43   作者:GQ17-kawaler  
这篇文章主要介绍了深度学习入门之Pytorch 数据增强的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

数据增强

卷积神经网络非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。

2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下。

常用的数据增强方法

常用的数据增强方法如下:
1.对图片进行一定比例缩放
2.对图片进行随机位置的截取
3.对图片进行随机的水平和竖直翻转
4.对图片进行随机角度的旋转
5.对图片进行亮度、对比度和颜色的随机变化

这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法。

import sys
sys.path.append('..')

from PIL import Image
from torchvision import transforms as tfs

# 读入一张图片
im = Image.open('./cat.png')
im

随机比例放缩

随机比例缩放主要使用的是 torchvision.transforms.Resize() 这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小,第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,你可以手动去改这个参数,更多的信息可以看看文档

# 比例缩放
print('before scale, shape: {}'.format(im.size))
new_im = tfs.Resize((100, 200))(im)
print('after scale, shape: {}'.format(new_im.size))
new_im

随机位置截取

随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop(),传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop(),同样传入介曲初的图片的大小作为参数,会在图片的中心进行截取

# 随机裁剪出 100 x 100 的区域
random_im1 = tfs.RandomCrop(100)(im)
random_im1

# 中心裁剪出 100 x 100 的区域
center_im = tfs.CenterCrop(100)(im)
center_im

随机的水平和竖直方向翻转

对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 torchvision.transforms.RandomHorizontalFlip()torchvision.transforms.RandomVerticalFlip()

# 随机水平翻转
h_filp = tfs.RandomHorizontalFlip()(im)
h_filp

# 随机竖直翻转
v_flip = tfs.RandomVerticalFlip()(im)
v_flip

随机角度旋转

一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation() 来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转

rot_im = tfs.RandomRotation(45)(im)
rot_im

亮度、对比度和颜色的变化

除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色

# 亮度
bright_im = tfs.ColorJitter(brightness=1)(im) # 随机从 0 ~ 2 之间亮度变化,1 表示原图
bright_im

# 对比度
contrast_im = tfs.ColorJitter(contrast=1)(im) # 随机从 0 ~ 2 之间对比度变化,1 表示原图
contrast_im

# 颜色
color_im = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化
color_im


上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 torchvision.transforms.Compose(),下面我们举个例子

im_aug = tfs.Compose([
  tfs.Resize(120),
  tfs.RandomHorizontalFlip(),
  tfs.RandomCrop(96),
  tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
])
import matplotlib.pyplot as plt
%matplotlib inline
nrows = 3
ncols = 3
figsize = (8, 8)
_, figs = plt.subplots(nrows, ncols, figsize=figsize)
for i in range(nrows):
  for j in range(ncols):
    figs[i][j].imshow(im_aug(im))
    figs[i][j].axes.get_xaxis().set_visible(False)
    figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()


可以看到每次做完增强之后的图片都有一些变化,所以这就是我们前面讲的,增加了一些'新'数据
下面我们使用图像增强进行训练网络,看看具体的提升究竟在什么地方,使用 ResNet 进行训练

使用数据增强

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train, resnet
from torchvision import transforms as tfs
# 使用数据增强
def train_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(120),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(96),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

def test_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=test_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

不使用数据增强

# 不使用数据增强
def data_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)


从上面可以看出,对于训练集,不做数据增强跑 10 次,准确率已经到了 95%,而使用了数据增强,跑 10 次准确率只有 75%,说明数据增强之后变得更难了。

而对于测试集,使用数据增强进行训练的时候,准确率会比不使用更高,因为数据增强提高了模型应对于更多的不同数据集的泛化能力,所以有更好的效果。

以上就是深度学习入门之Pytorch 数据增强的实现的详细内容,更多关于Pytorch 数据增强的资料请关注脚本之家其它相关文章!

相关文章

  • 浅谈python已知元素,获取元素索引(numpy,pandas)

    浅谈python已知元素,获取元素索引(numpy,pandas)

    今天小编就为大家分享一篇浅谈python已知元素,获取元素索引(numpy,pandas),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • Python格式化日期时间操作示例

    Python格式化日期时间操作示例

    这篇文章主要介绍了Python格式化日期时间操作,结合实例形式分析了Python使用time模块针对日期时间的获取、转换、格式化、判断等相关操作技巧,需要的朋友可以参考下
    2018-06-06
  • 使用tqdm显示Python代码执行进度功能

    使用tqdm显示Python代码执行进度功能

    在使用Python执行一些比较耗时的操作时,为了方便观察进度,通常使用进度条的方式来可视化呈现。这篇文章主要介绍了使用tqdm显示Python代码执行进度,需要的朋友可以参考下
    2019-12-12
  • Pycharm小白级简单使用教程

    Pycharm小白级简单使用教程

    pycharm是一种Python IDE,能够帮助我们在编写代码时提高效率。 这篇文章主要介绍了Pycharm小白级简单使用教程,需要的朋友可以参考下
    2020-01-01
  • Python 通配符删除文件的实例

    Python 通配符删除文件的实例

    下面小编就为大家分享一篇Python 通配符删除文件的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • 浅谈pytorch中torch.max和F.softmax函数的维度解释

    浅谈pytorch中torch.max和F.softmax函数的维度解释

    这篇文章主要介绍了浅谈pytorch中torch.max和F.softmax函数的维度解释,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • Python中的getopt函数使用详解

    Python中的getopt函数使用详解

    这篇文章主要介绍了Python中的getopt函数使用详解,是Python进阶学习中的重要知识,需要的朋友可以参考下
    2015-07-07
  • python Sweetviz探索性数据可视化分析库使用特征详解

    python Sweetviz探索性数据可视化分析库使用特征详解

    这篇文章主要为大家介绍了python Sweetviz探索性数据可视化分析库特征使用详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2024-01-01
  • Python 常用内置模块超详细梳理总结

    Python 常用内置模块超详细梳理总结

    模块是一个包含索引你定义的函数和变量的文件,其扩展名为.py。模块可以被其他程序引入,以使用该模块中的函数等功能。这也是使用python标准库的方法
    2022-03-03
  • python随机获取列表中某一元素的方法

    python随机获取列表中某一元素的方法

    本文主要介绍了python随机获取列表中某一元素的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02

最新评论