Python中的Dataset和Dataloader详解

 更新时间:2023年07月29日 08:53:28   作者:菜菜01  
这篇文章主要介绍了Python中的Dataset和Dataloader详解,DataLoader与DataSet是PyTorch数据读取的核心,是构建一个可迭代的数据装载器,每次执行循环的时候,就从中读取一批Batchsize大小的样本进行训练,需要的朋友可以参考下

Dataset,Dataloader是什么?

  • Dataset:负责可被Pytorch使用的数据集的创建
  • Dataloader:向模型中传递数据

为什么要了解Dataloader

​ 因为你的神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。

因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。

​ 通常,我们在默认或知名数据集(如 MNIST 或 CIFAR)上训练神经网络,可以轻松地实现预测和分类类型问题的超过 90% 的准确度。

但是那是因为这些数据集组织整齐且易于预处理。

但是处理自己的数据集时,我们常常无法达到这样高的准确率

Dataloader 的使用

载入相关类

from torch.utils.data import Dataloader

设置相关参数

from torch.utils.data import DataLoader
DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
 )
"""
dataset:是数据集
batch_size:是指一次迭代中使用的训练样本数。通常我们将数据分成训练集和测试集,并且我们可能有不同的批量大小。
shuffle:是传递给 DataLoader 类的另一个参数。该参数采用布尔值(真/假)。如果 shuffle 设置为 True,则所有样本都被打乱并分批加载。否则,它们会被一个接一个地发送,而不会进行任何洗牌。
num_workers:允许多处理来增加同时运行的进程数
collate_fn:合并数据集
pin_memory:锁页内存:将张量固定在内存中
"""

以minist为例子

# Import MNIST
from torchvision.datasets import MNIST
# Download and Save MNIST 
data_train = MNIST('~/mnist_data', train=True, download=True)
# Print Data
print(data_train)
print(data_train[12])
#Dataset MNIST Number of datapoints: 60000 Root location: /Users/viharkurama/mnist_data Split: Train (<PIL.Image.Image image mode=L size=28x28 at 0x11164A100>, 3)

现在让尝试提取元组,其中第一个值对应于图像,第二个值对应于其各自的标签。

下面是代码片段:

import matplotlib.pyplot as plt
random_image = data_train[0][0]
random_image_label = data_train[0][1]
# Print the Image using Matplotlib
plt.imshow(random_image)
print("The label of the image is:", random_image_label)

让我们使用 DataLoader 类来加载数据集,如下所示。

import torch
from torchvision import transforms
data_train = torch.utils.data.DataLoader(
    MNIST(
          '~/mnist_data', train=True, download=True, 
          transform = transforms.Compose([
              transforms.ToTensor()
          ])),
          batch_size=64,
          shuffle=True
          )
for batch_idx, samples in enumerate(data_train):
      print(batch_idx, samples)

这就是我们使用 DataLoader 加载简单数据集的方式。 但是,我们不能总是对每个数据集都依赖已经有的数据集,要是自己的数据集怎么办

定义自己的数据集

我们将创建一个由数字和文本组成的简单自定义数据集

先介绍两个方法

#__getitem__() 方法通过索引返回数据集中选定的样本。
#__len__() 方法返回数据集的总大小。例如,如果您的数据集包含 1,00,000 个样本,则 len 方法应返回 1,00,000。
class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError

​ 创建自定义数据集并不复杂,但作为加载数据的典型过程的附加步骤,有必要构建一个接口以获得良好的抽象(至少可以说是一个很好的语法糖)。

现在我们将创建一个包含数字及其平方值的新数据集。 让我们将数据集称为 SquareDataset。 其目的是返回 [a,b] 范围内的值的平方。

下面是相关代码:

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
class SquareDataset(Dataset):
     def __init__(self, a=0, b=1):
         super(Dataset, self).__init__()
         assert a <= b
         self.a = a
         self.b = b
     def __len__(self):
         return self.b - self.a + 1
     def __getitem__(self, index):
        assert self.a <= index <= self.b
        return index, index**2
data_train = SquareDataset(a=1,b=64)
data_train_loader = DataLoader(data_train, batch_size=64, shuffle=True)
print(len(data_train))

​ 在上面的代码块中,我们创建了一个名为 SquareDataset 的 Python 类,它继承了 PyTorch 的 Dataset 类。

接下来,我们调用了一个 init() 构造函数,其中 a 和 b 分别被初始化为 0 和 1。 超类用于从继承的 Dataset 类中访问 len 和 get_item 方法。

接下来我们使用 assert 语句来检查 a 是否小于或等于 b,因为我们想要创建一个数据集,其中值将位于 a 和 b 之间。

​ 然后,我们使用 SquareDataset 类创建了一个数据集,其中数据值的范围为 1 到 64。我们将其加载到名为 data_train 的变量中。

最后,Dataloader 类在 data_train_loader 中存储的数据上创建了一个迭代器,batch_size 初始化为 64,shuffle 设置为 True。

如何使用transform

​ 当你学会怎么定义自己的数据集的时候,你可能会想要更近 一步的操作,对于你自己的数据集进行剪切或者变换

​ 以CIFAR10为例子

  • 将所有图像调整为 32×32
  • 对图像应用中心裁剪变换
  • 将裁剪后的图像转换为张量
  • 标准化图像

导入必要的模块

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

接下来,我们将定义一个名为 transforms 的变量,我们在其中按顺序编写所有预处理步骤。我们使用 Compose 类将所有转换操作链接在一起。

transform = transforms.Compose([
    # resize
    transforms.Resize(32),
    # center-crop
    transforms.CenterCrop(32),
    # to-tensor
    transforms.ToTensor(),
    # normalize
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
"""
resize:此调整大小转换将所有图像转换为定义的大小。在这种情况下,我们要将所有图像的大小调整为 32×32。因此,我们将 32 作为参数传递。
center-crop:接下来,我们使用 CenterCrop 变换裁剪图像。 我们发送的参数也是分辨率/大小,但由于我们已经将图像大小调整为 32x32,因此图像将与此裁剪中心对齐。 这意味着图像将从中心裁剪 32 个单位(垂直和水平)。
to-tensor:我们使用 ToTensor() 方法将图像转换为张量数据类型。
normalize:这将张量中的所有值归一化,使它们位于 0.5 和 1 之间。
"""

在下一步中,在执行我们刚刚定义的转换之后,我们将使用 trainloader 将 CIFAR 数据集加载到训练集中。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=False)

到此这篇关于Python中的Dataset和Dataloader详解的文章就介绍到这了,更多相关Dataset和Dataloader详解内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python绘图库之pyqtgraph的用法详解

    Python绘图库之pyqtgraph的用法详解

    PyQtGraph建立在Qt QGraphicsScene的原生库,可提供更好更高性能绘图能力,特别是对于实时数据,可以提供交互性和使用Qt图形小部件轻松自定义绘图的能力。本文就来解释一下pyqtgraph的用法,需要的可以收藏一下
    2022-12-12
  • python实现canny边缘检测

    python实现canny边缘检测

    本文主要讲解了canny边缘检测原理:计算梯度幅值和方向、根据角度对幅值进行非极大值抑制、用双阈值算法检测和连接边缘以及python 实现
    2020-09-09
  • Python利用myqr库创建自己的二维码

    Python利用myqr库创建自己的二维码

    这篇文章主要给大家介绍了关于Python利用myqr库创建自己的二维码的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11
  • 用python制作个视频下载器

    用python制作个视频下载器

    这篇文章主要介绍了如何用python制作个视频下载器,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2021-02-02
  • python 遍历磁盘目录的三种方法

    python 遍历磁盘目录的三种方法

    这篇文章主要介绍了python 遍历磁盘目录的三种方法,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下
    2021-04-04
  • Python操作json的方法实例分析

    Python操作json的方法实例分析

    这篇文章主要介绍了Python操作json的方法,结合实例形式简单分析了Python针对json数据使用解码loads()和编码dumps()相关操作技巧,需要的朋友可以参考下
    2018-12-12
  • 总结Python使用过程中的bug

    总结Python使用过程中的bug

    今天给大家带来的是关于Python的相关知识,文章围绕着Python使用过程中的bug展开,文中有非常详细的介绍,需要的朋友可以参考下
    2021-06-06
  • Python并发编程之进程间通信原理及实现解析

    Python并发编程之进程间通信原理及实现解析

    这篇文章主要为大家介绍了Python并发编程之进程间通信原理及实现解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2024-01-01
  • Python类型注解必备利器typing模块全面解读

    Python类型注解必备利器typing模块全面解读

    在Python 3.5版本后引入的typing模块为Python的静态类型注解提供了支持,这个模块在增强代码可读性和维护性方面提供了帮助,本文将深入探讨typing模块,介绍其基本概念、常用类型注解以及使用示例,以帮助读者更全面地了解和应用静态类型注解
    2024-01-01
  • Python利用tkinter和socket实现端口扫描

    Python利用tkinter和socket实现端口扫描

    这篇文章主要为大家详细介绍了Python如何利用tkinter和socket实现端口扫描功能,文中的示例代码讲解详细,感兴趣的小伙伴可以尝试一下
    2022-12-12

最新评论