PyTorch中torch.utils.data.DataLoader实例详解

 更新时间:2022年09月27日 09:46:31   作者:进击的程小白  
torch.utils.data.DataLoader主要是对数据进行batch的划分,下面这篇文章主要给大家介绍了关于PyTorch中torch.utils.data.DataLoader的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下

1、dataset:(数据类型 dataset)

输入的数据类型,这里是原始数据的输入。PyTorch内也有这种数据结构。

2、batch_size:(数据类型 int)

批训练数据量的大小,根据具体情况设置即可(默认:1)。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。每次是随机读取大小为batch_size。如果dataset中的数据个数不是batch_size的整数倍,这最后一次把剩余的数据全部输出。若想把剩下的不足batch size个的数据丢弃,则将drop_last设置为True,会将多出来不足一个batch的数据丢弃。

3、shuffle:(数据类型 bool)

洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。

4、collate_fn:(数据类型 callable,没见过的类型)

将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。

5、batch_sampler:(数据类型 Sampler)

批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。

6、sampler:(数据类型 Sampler)

采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。

7、num_workers:(数据类型 Int)

工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。注意:这个数字必须是大于等于0的,负数估计会出错。

8、pin_memory:(数据类型 bool)

内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。

9、drop_last:(数据类型 bool)

丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。

10、timeout:(数据类型 numeric)

超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。

11、worker_init_fn(数据类型 callable,没见过的类型)

子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据。

对batch_size举例分析:

"""
    批训练,把数据变成一小批一小批数据进行训练。
    DataLoader就是用来包装所使用的数据,每次抛出一批数据
"""
import torch
import torch.utils.data as Data
 
BATCH_SIZE = 5
 
x = torch.linspace(1, 11, 11)
y = torch.linspace(11, 1, 11)
print(x)
print(y)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    # 从数据库中每次抽出batch size个样本
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # num_workers=2,
)
 
def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training
            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
 
if __name__ == '__main__':
    show_batch()

输出为:

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])
tensor([11., 10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.])
steop:0, batch_x:tensor([ 3.,  2.,  8., 11.,  1.]), batch_y:tensor([ 9., 10.,  4.,  1., 11.])
steop:1, batch_x:tensor([ 5.,  6.,  7.,  4., 10.]), batch_y:tensor([7., 6., 5., 8., 2.])
steop:2, batch_x:tensor([9.]), batch_y:tensor([3.])
steop:0, batch_x:tensor([ 9.,  7., 10.,  2.,  4.]), batch_y:tensor([ 3.,  5.,  2., 10.,  8.])
steop:1, batch_x:tensor([ 5., 11.,  3.,  6.,  8.]), batch_y:tensor([7., 1., 9., 6., 4.])
steop:2, batch_x:tensor([1.]), batch_y:tensor([11.])
steop:0, batch_x:tensor([10.,  5.,  7.,  4.,  2.]), batch_y:tensor([ 2.,  7.,  5.,  8., 10.])
steop:1, batch_x:tensor([3., 9., 1., 8., 6.]), batch_y:tensor([ 9.,  3., 11.,  4.,  6.])
steop:2, batch_x:tensor([11.]), batch_y:tensor([1.])
 
Process finished with exit code 0

若drop_last=True

"""
    批训练,把数据变成一小批一小批数据进行训练。
    DataLoader就是用来包装所使用的数据,每次抛出一批数据
"""
import torch
import torch.utils.data as Data
 
BATCH_SIZE = 5
 
x = torch.linspace(1, 11, 11)
y = torch.linspace(11, 1, 11)
print(x)
print(y)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    # 从数据库中每次抽出batch size个样本
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # num_workers=2,
    drop_last=True,
)
 
def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training
            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
 
if __name__ == '__main__':
    show_batch()

对应的输出为:

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])
tensor([11., 10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.])
steop:0, batch_x:tensor([ 9.,  2.,  7.,  4., 11.]), batch_y:tensor([ 3., 10.,  5.,  8.,  1.])
steop:1, batch_x:tensor([ 3.,  5., 10.,  1.,  8.]), batch_y:tensor([ 9.,  7.,  2., 11.,  4.])
steop:0, batch_x:tensor([ 5., 11.,  6.,  1.,  2.]), batch_y:tensor([ 7.,  1.,  6., 11., 10.])
steop:1, batch_x:tensor([ 3.,  4., 10.,  8.,  9.]), batch_y:tensor([9., 8., 2., 4., 3.])
steop:0, batch_x:tensor([10.,  4.,  9.,  8.,  7.]), batch_y:tensor([2., 8., 3., 4., 5.])
steop:1, batch_x:tensor([ 6.,  1., 11.,  2.,  5.]), batch_y:tensor([ 6., 11.,  1., 10.,  7.])
 
Process finished with exit code 0

总结

到此这篇关于PyTorch中torch.utils.data.DataLoader的文章就介绍到这了,更多相关PyTorch torch.utils.data.DataLoader内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 详解pandas df.iloc[]的典型用法

    详解pandas df.iloc[]的典型用法

    本文主要介绍了详解pandas df.iloc[]的典型用法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-08-08
  • 浅析form标签中的GET和POST提交方式区别

    浅析form标签中的GET和POST提交方式区别

    在HTML中,form表单的作用是收集标签中的内容<form>...</form> 中间可以由访问者添加类似于文本,选择,或者一些控制模块等等.然后这些内容将会被送到服务端
    2021-09-09
  • 在CentOS6上安装Python2.7的解决方法

    在CentOS6上安装Python2.7的解决方法

    在CentOS6上yum安装工具是基于Python2.6.6的,所以在CentOS6上默认安装的是Python2.6.6,因为要在服务器系统为CentOS6上部署生产环境,但是代码都是基于Python2.7写的,所有遇到了问题,下面通过本文给大家介绍下在CentOS6上安装Python2.7的解决方法,一起看看吧
    2018-01-01
  • Python六大开源框架对比

    Python六大开源框架对比

    在这篇文章里,我们将为Python Web开发者回顾基于Python的6大Web应用框架。无论你是出于爱好还是需求,这六大框架都可能会成为你工作上不错的得力助手。
    2015-10-10
  • Python处理浮点数的实用技巧分享

    Python处理浮点数的实用技巧分享

    四舍五入是一种常见的数学操作,它用于将数字舍入到指定的精度,Python 提供了多种方法来实现四舍五入操作,本文将详细介绍这些方法,希望对大家有所帮助
    2024-12-12
  • Python中函数带括号和不带括号的区别及说明

    Python中函数带括号和不带括号的区别及说明

    这篇文章主要介绍了Python中函数带括号和不带括号的区别及说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11
  • 详解在SpringBoot如何优雅的使用多线程

    详解在SpringBoot如何优雅的使用多线程

    这篇文章主要带大家快速了解一下@Async注解的用法,包括异步方法无返回值、有返回值,最后总结了@Async注解失效的几个坑,感兴趣的小伙伴可以了解一下
    2023-02-02
  • python将logging模块封装成单独模块并实现动态切换Level方式

    python将logging模块封装成单独模块并实现动态切换Level方式

    这篇文章主要介绍了python将logging模块封装成单独模块并实现动态切换Level方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • pygame实现雷电游戏雏形开发

    pygame实现雷电游戏雏形开发

    这篇文章主要为大家详细介绍了pygame实现雷电游戏开发代码,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-11-11
  • 教你使用Python连接oracle

    教你使用Python连接oracle

    今天教各位小伙伴怎么用Python连接oracle,文中附带非常详细的图文示例,对正在学习的小伙伴们很有帮助哟,需要的朋友可以参考下
    2021-05-05

最新评论