Pytorch Dataset,TensorDataset,Dataloader,Sampler关系解读

 更新时间:2023年09月11日 16:45:34   作者:czg792845236  
这篇文章主要介绍了Pytorch Dataset,TensorDataset,Dataloader,Sampler关系,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

Dataloader

Dataloader是数据加载器,组合数据集和采样器,并在数据集上提供单线程或多线程的迭代器。

所以Dataloader的参数必然需要指定数据集Dataset和采样器Sampler。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
  • dataset (Dataset) – 数据集。
  • batch_size (int, optional) – 每个batch加载样本数。
  • shuffle (bool, optional) – True则打乱数据.
  • sampler (Sampler, optional) – 采样器,如指定则忽略shuffle参数。
  • num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载
  • collate_fn (callable, optional) – 获取batch数据的回调函数,也就是说可以在这个函数中修改batch的形式
  • pin_memory (bool, optional) –
  • drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。

Dataset和TensorDataset

所有其他数据集都应该进行子类化。所有子类应该override __len__ __getitem__ ,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。

TensorDataset是Dataset的子类,已经复写了 __len__ __getitem__ 方法,只要传入张量即可,它通过第一个维度进行索引。

TensorDataset示例

所以TensorDataset说白了就是将输入的tensors捆绑在一起,然后 __len__ 是任何一个tensor的维度, __getitem__ 表示每个tensor取相同的索引,然后将这个结果组成一个元组,源码如下,要好好理解它通过第一个维度进行索引的意思(针对tensors里面的每一个tensor而言)。

class TensorDataset(Dataset):
	def __init__(self,*tensors):
		assert all(tensors[0].size(0)==tensor.size(0) for tensor in tensors)
		self.tensors = tensors
	def __getitem__(self,index):
		return tuple(tensor[index] for tensor in self.tensors)
	def __len__(self):
		return self.tensors[0].size(0)

Sampler和RandomSampler

Sampler与Dataset类似,是采样器的基础类。

每个采样器子类必须提供一个 __iter__ 方法,提供一种迭代数据集元素的索引的方法,以及返回迭代器长度的 __len__ 方法。

所以Sampler必然是关于索引的迭代器,也就是它的输出是索引。

而RandomSampler与TensorDataset类似,RandomSamper已经实现了 __iter__ __len__ 方法,只需要传入数据集即可。

在这里插入图片描述

猜想理解RandomSampler的实现方式,考虑到这个类实现需要传入Dataset,所以 __len__ 就是Dataset的 __len__ ,然后 __iter__ 就可以随便搞一个随机函数对range(length)随机即可。

综合示例

结合TensorDataset和RandomSampler使用Dataloader

这里即可理解Dataloader这个数据加载器其实就是组合数据集和采样器的组合。

所以那就是先根据Sampler随机拿到一个索引,再用这个索引到Dataset中取tensors里每个tensor对应索引的数据来组成一个元组。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python删除目录的三种方法

    python删除目录的三种方法

    本文主要介绍了python删除目录的三种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2024-12-12
  • Python的历史与优缺点整理

    Python的历史与优缺点整理

    在本篇文章里小编给大家分享的是关于Python优缺点及基础知识点整理内容,有需要的朋友们可以参考下。
    2020-05-05
  • 举例讲解Python设计模式编程的代理模式与抽象工厂模式

    举例讲解Python设计模式编程的代理模式与抽象工厂模式

    这篇文章主要介绍了Python编程的代理模式与抽象工厂模式,文中举了两个简单的小例子来说明这两种设计模式的思路在Python编程中的体现,需要的朋友可以参考下
    2016-01-01
  • Python操作PDF文件之实现A3页面转A4

    Python操作PDF文件之实现A3页面转A4

    这篇文章主要为大家详细介绍了Python操作PDF文件之实现A3页面转A4功能的相关资料,文中的示例代码讲解详细,感兴趣的小伙伴可以了解一下
    2022-11-11
  • python自动化测试用例全对偶组合与全覆盖组合比较

    python自动化测试用例全对偶组合与全覆盖组合比较

    这篇文章主要为大家介绍了python自动化测试用例全对偶组合与全覆盖组合比较,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-06-06
  • Python使用tarfile模块实现免费压缩解压

    Python使用tarfile模块实现免费压缩解压

    Python自带的tarfile模块可以方便读取tar归档文件,厉害的是可以处理使用gzip和bz2压缩归档文件tar.gz和tar.bz2,这篇文章主要介绍了Python使用tarfile模块实现免费压缩解压,需要的朋友可以参考下
    2024-03-03
  • Python中列表和元组的使用方法和区别详解

    Python中列表和元组的使用方法和区别详解

    这篇文章主要介绍了Python中列表和元组的使用方法和区别详解的相关资料,需要的朋友可以参考下
    2016-07-07
  • Python库 Django 的简介、安装、用法入门教程

    Python库 Django 的简介、安装、用法入门教程

    Django是Python最流行的Web框架之一,它帮助开发者快速、高效地构建功能强大的Web应用程序,接下来我们将从简介、安装到用法详解,全方位解析Django的世界,感兴趣的朋友跟随小编一起看看吧
    2025-08-08
  • Python Pexpect库的简单使用方法

    Python Pexpect库的简单使用方法

    这篇文章主要介绍了Python Pexpect库的简单使用方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-01-01
  • Python pickle模块的使用指南

    Python pickle模块的使用指南

    Python pickle模块用于对象序列化与反序列化,支持dump/load方法及自定义类,需注意安全风险,建议在受控环境中使用,适用于模型持久化、缓存及跨进程通信,选择高效协议版本以提升性能
    2025-09-09

最新评论