Pytorch中TensorDataset与DataLoader的使用方式

 更新时间:2023年09月09日 08:50:11   作者:Arxan_hjw  
这篇文章主要介绍了Pytorch中TensorDataset与DataLoader的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

TensorDataset与DataLoader的使用

TensorDataset

TensorDataset本质上与python zip方法类似,对数据进行打包整合。

官方文档说明:

**Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.*

Parameters:
tensors (Tensor) – tensors that have the same size of the first dimension.

该类通过每一个 tensor 的第一个维度进行索引。

因此,该类中的 tensor 第一维度必须相等。

import torch
from torch.utils.data import TensorDataset
# a的形状为(4*3)
a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
# b的第一维与a相同
b = torch.tensor([1,2,3,4])
train_data = TensorDataset(a,b)
print(train_data[0:4])

输出结果如下:

(tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4]]), tensor([1, 2, 3, 4]))

DataLoader

DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存。

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
b = torch.tensor([1,2,3,4])
train_data = TensorDataset(a,b)
data = DataLoader(train_data, batch_size=2, shuffle=True)
for i, j in enumerate(data):
    x, y = j
    print(' batch:{0} x:{1}  y: {2}'.format(i, x, y))

输出:

 batch:0 x:tensor([[1, 1, 1],
        [2, 2, 2]])  y: tensor([1, 2])
 batch:1 x:tensor([[4, 4, 4],
        [3, 3, 3]])  y: tensor([4, 3])

Pytorch Dataset,TensorDataset,Dataloader,Sampler关系

Dataloader

Dataloader是数据加载器,组合数据集和采样器,并在数据集上提供单线程或多线程的迭代器。

所以Dataloader的参数必然需要指定数据集Dataset和采样器Sampler。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
  • dataset (Dataset) – 数据集。
  • batch_size (int, optional) – 每个batch加载样本数。
  • shuffle (bool, optional) – True则打乱数据.
  • sampler (Sampler, optional) – 采样器,如指定则忽略shuffle参数。
  • num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载
  • collate_fn (callable, optional) – 获取batch数据的回调函数,也就是说可以在这个函数中修改batch的形式
  • pin_memory (bool, optional) –
  • drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。

Dataset和TensorDataset

所有其他数据集都应该进行子类化。所有子类应该override __len__ __getitem__ ,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。

TensorDataset是Dataset的子类,已经复写了 __len__ __getitem__ 方法,只要传入张量即可,它通过第一个维度进行索引。

所以TensorDataset说白了就是将输入的tensors捆绑在一起,然后 __len__ 是任何一个tensor的维度, __getitem__ 表示每个tensor取相同的索引,然后将这个结果组成一个元组,源码如下,要好好理解它通过第一个维度进行索引的意思(针对tensors里面的每一个tensor而言)。

class TensorDataset(Dataset):
	def __init__(self,*tensors):
		assert all(tensors[0].size(0)==tensor.size(0) for tensor in tensors)
		self.tensors = tensors
	def __getitem__(self,index):
		return tuple(tensor[index] for tensor in self.tensors)
	def __len__(self):
		return self.tensors[0].size(0)

Sampler和RandomSampler

Sampler与Dataset类似,是采样器的基础类。

每个采样器子类必须提供一个 __iter__ 方法,提供一种迭代数据集元素的索引的方法,以及返回迭代器长度的 __len__ 方法。

所以Sampler必然是关于索引的迭代器,也就是它的输出是索引。

而RandomSampler与TensorDataset类似,RandomSamper已经实现了 __iter__ __len__ 方法,只需要传入数据集即可。

猜想理解RandomSampler的实现方式,考虑到这个类实现需要传入Dataset,所以 __len__ 就是Dataset的 __len__ ,然后 __iter__ 就可以随便搞一个随机函数对range(length)随机即可。

综合示例

结合TensorDataset和RandomSampler使用Dataloader

这里即可理解Dataloader这个数据加载器其实就是组合数据集和采样器的组合。所以那就是先根据Sampler随机拿到一个索引,再用这个索引到Dataset中取tensors里每个tensor对应索引的数据来组成一个元组。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • pandas中的DataFrame数据遍历解读

    pandas中的DataFrame数据遍历解读

    这篇文章主要介绍了pandas中的DataFrame数据遍历解读,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-12-12
  • Python数据分析之堆叠数组函数示例总结

    Python数据分析之堆叠数组函数示例总结

    这篇文章主要为大家介绍了Python数据分析之堆叠数组函数示例总结,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-02-02
  • Python基础教程之if判断,while循环,循环嵌套

    Python基础教程之if判断,while循环,循环嵌套

    这篇文章主要介绍了Python基础教程之if判断,while循环,循环嵌套 的相关知识,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-04-04
  • Python实现RabbitMQ6种消息模型的示例代码

    Python实现RabbitMQ6种消息模型的示例代码

    这篇文章主要介绍了Python实现RabbitMQ6种消息模型的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-03-03
  • python自动定时任务schedule库的使用方法

    python自动定时任务schedule库的使用方法

    当你需要在 Python 中定期执行任务时,schedule 库是一个非常实用的工具,它可以帮助你自动化定时任务,本文给大家介绍了python自动定时任务schedule库的使用方法,需要的朋友可以参考下
    2024-02-02
  • Django框架序列化与反序列化操作详解

    Django框架序列化与反序列化操作详解

    这篇文章主要介绍了Django框架序列化与反序列化操作,结合实例形式详细分析了Django框架Serializer类操作对象序列化及反序列化相关实现技巧,需要的朋友可以参考下
    2019-11-11
  • 利用python微信库itchat实现微信自动回复功能

    利用python微信库itchat实现微信自动回复功能

    最近发现了一个特别好玩的Python 微信库itchat,可以实现自动回复等多种功能,下面这篇文章主要给大家介绍了利用python微信库itchat实现微信自动回复功能的相关资料,需要的朋友可以参考学习,下面来一起看看吧。
    2017-05-05
  • Python 中 Kwargs 解析的最佳实践教程

    Python 中 Kwargs 解析的最佳实践教程

    这篇文章主要介绍了Python中Kwargs解析的最佳实践,使用 kwargs,我们可以编写带有任意数量关键字参数的函数,当我们想为函数提供灵活的接口时,这会很有用,需要的朋友可以参考下
    2023-06-06
  • Python 中数组和数字相乘时的注意事项说明

    Python 中数组和数字相乘时的注意事项说明

    这篇文章主要介绍了Python 中数组和数字相乘时的注意事项说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-05-05
  • 详解pandas数据合并与重塑(pd.concat篇)

    详解pandas数据合并与重塑(pd.concat篇)

    这篇文章主要介绍了详解pandas数据合并与重塑(pd.concat篇),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07

最新评论