Pytorch使用DataLoader实现批量加载数据

 更新时间:2024年02月27日 09:47:51   作者:Vic·Tory  
这篇文章主要介绍了Pytorch使用DataLoader实现批量加载数据方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

在进行模型训练时,需要把数据按照固定的形式分批次投喂给模型,在PyTorch中通过torch.utils.data库的DataLoader完成分批次返回数据。

构造DataLoader首先需要一个Dataset数据源,Dataset完成数据的读取并可以返回单个数据,然后DataLoader在此基础上完成数据清洗、打乱等操作并按批次返回数据。

Dataset

PyTorch将数据源分为两种类型:类似Map型(Map-style datasets)和可迭代型(Iterable-style datasets)。

Map风格的数据源可以通过索引idx对数据进行查找:dataset[idx],它需要继承Dataset类,并且重写__getitem__() 方法完成根据索引值获取数据和__len__() 方法返回数据的总长度。

可迭代型可以迭代获取其数据,但没有固定的长度,因此也不能通过下标获得数据,通常用于无法获取全部数据或者流式返回的数据。它继承自IterableDataset类,并且需要实现__iter__()方法来完成对数据集的迭代和返回。

如下所示为自定义的数据源MySet,它完成数据的读取,这里假定为[1, 9] 9个数据,然后重写了__getitem__() 和__len__() 方法

from torch.utils.data import Dataset, DataLoader, Sampler

class MySet(Dataset):
	# 读取数据
    def __init__(self):
        self.data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
	# 根据索引返回数据
    def __getitem__(self, idx):
        return self.data[idx]
	# 返回数据集总长度
    def __len__(self):
        return len(self.data)

DataLoader

其构造函数如下:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
  • dataset:Dataset类型,从其中加载数据 batch_size:int,可选。每个batch加载多少样本
  • batch_size: 一个批次的数据个数
  • shuffle:bool,可选。为True时表示每个epoch都对数据进行洗牌
  • sampler:Sampler,可选。获取下一个数据的方法。
  • batch_sampler:获取下一批次数据的方法
  • num_workers:int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。
  • collate_fn:callable,可选,自定义处理数据并返回。
  • pin_memory:bool,可选,True代表将数据Tensor放入CUDA的pin储存
  • drop_last:bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。

Sampler索引

既然DataLoader根据索引值从Dataset中获取数据,那么如何获取一个批次数据的索引,索引值应该如何排列才能实现随机的效果?这就需要Sampler了,它可以对索引进行shuffle操作来打乱顺序,并且根据batch size一次返回指定个数的索引序列。

在初始化DataLoader时通过sampler属性指定获取下一个数据的索引的方法,或者batch_sampler属性指定获取下一个批次数据的索引。

当我们设置DataLoader的shuffle属性为True时,会根据batch_size属性传入的批次大小自动构造sample返回下一个批次的索引。

当我们不启用shuffle属性时,就可以通过batch_sampler属性自定义sample来返回下一批的索引,注意这时候不可用使用 batch_size, shuffle, sampler, 和drop_last属性。

如下所示为自定义MySampler,它继承自Sampler,由传入dataset的长度产生对应的索引,例如上面有9个数据,那么产生索引[0, 8]。

根据批次大小batch_size计算出总批次数,例如当batchsize是3,那么9/3=3,即总共有3个批次。

重写__iter__()方法按批次返回索引,即第一批返回[0, 1, 2],第二批返回[3, 4, 5]以此类推。

__len__()方法返回总的批次数,即3个批次。

class MySampler(Sampler):
    def __init__(self, dataset, batchsize):
        super(Sampler, self).__init__()
        self.dataset = dataset
        self.batch_size = batchsize		# 每一批数据量
        self.indices = range(len(dataset))	# 生成数据集的索引
        self.count = int(len(dataset) / self.batch_size)	# 一共有多少批

    def __iter__(self):
        for i in range(self.count):
            yield self.indices[i * self.batch_size: (i + 1) * self.batch_size]

    def __len__(self):
        return self.count

collate处理数据

当我们拿到数据如果希望进行一些预处理而不是直接返回,这时候就需要collate_fn属性来指定处理和返回数据的方法,如果不指定该属性,默认会将普通的NumPy数组转换为PyTorch的tensor并直接返回。

如下所示为自定义的my_collate()函数,默认传入获得的一个批次的数据data,例如之前返回一批数据[1, 2, 3],这里遍历数据并平方之后放在res数组中返回[1, 4, 9]

def my_collate(data):
    res = []
    for d in data:
        res.append(d ** 2)
    return res

有了上面的索引获取类MySampler和数据处理函数my_collate(),就可以使用DataLoader自定义获取批数据了。

首先DataLoader通过my_sampler返回的索引[0, 1, 2]去dataset拿到数据[1, 2, 3],然后传递给my_collate进行平方操作,然后返回一个批次的结果为[1, 4, 9],一共有三个批次的数据。

dataset = MySet()	# 定义数据集
my_sampler = MySampler(dataset, 3)		# 实例化MySampler

data_loader = DataLoader(dataset, batch_sampler=my_sampler, collate_fn=my_collate)

for data in data_loader:	# 按批次获取数据
    print(data)
'''
[1, 4, 9]
[16, 25, 36]
[49, 64, 81]
'''

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。 

相关文章

  • Django查询优化及ajax编码格式原理解析

    Django查询优化及ajax编码格式原理解析

    这篇文章主要介绍了Django查询优化及ajax编码格式原理解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-03-03
  • Python基础之Numpy的基本用法详解

    Python基础之Numpy的基本用法详解

    这篇文章主要介绍了Python基础之Numpy的基本用法详解,文中有非常详细的代码示例,对正在学习python基础的小伙伴们有非常好的帮助,需要的朋友可以参考下
    2021-05-05
  • pytest实战技巧之参数化基本用法和多种方式

    pytest实战技巧之参数化基本用法和多种方式

    本文介绍了pytest参数化的基本用法和多种方式,帮助读者更好地使用这个功能,同时,还介绍了一些高级技巧,如动态生成参数名称、参数化的组合和动态生成参数化装饰器,帮助读者更灵活地使用参数化,感兴趣的朋友参考下吧
    2023-12-12
  • Python for循环生成列表的实例

    Python for循环生成列表的实例

    今天小编就为大家分享一篇Python for循环生成列表的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • 利用Pytorch实现简单的线性回归算法

    利用Pytorch实现简单的线性回归算法

    今天小编就为大家分享一篇利用Pytorch实现简单的线性回归算法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • 对Python 获取类的成员变量及临时变量的方法详解

    对Python 获取类的成员变量及临时变量的方法详解

    今天小编就为大家分享一篇对Python 获取类的成员变量及临时变量的方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • tensorflow实现测试时读取任意指定的check point的网络参数

    tensorflow实现测试时读取任意指定的check point的网络参数

    今天小编就为大家分享一篇tensorflow实现测试时读取任意指定的check point的网络参数,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • Python字符串对齐、删除字符串不需要的内容以及格式化打印字符

    Python字符串对齐、删除字符串不需要的内容以及格式化打印字符

    这篇文章主要给大家介绍了关于Python字符串对齐、删除字符串不需要的内容以及格式化打印字符的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01
  • Python对象循环引用垃圾回收算法详情

    Python对象循环引用垃圾回收算法详情

    这篇文章主要介绍了Python对象循环引用垃圾回收算法详情,文章围绕主题展开详细的内容戒杀,具有一定的参考价值,感兴趣的小伙伴可以参考一下
    2022-09-09
  • 学习Python需要哪些工具

    学习Python需要哪些工具

    这篇文章主要介绍了学习Python需要哪些工具,帮助大家开始学习python编程,感兴趣的朋友可以了解下
    2020-09-09

最新评论