Pytorch 如何加速Dataloader提升数据读取速度

 更新时间:2021年05月28日 09:23:29   作者:MKFMIKU  
这篇文章主要介绍了Pytorch 加速Dataloader提升数据读取速度的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。

很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。

这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:

如何解决这种问题呢?

在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256 ​

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,

那么对我们自己的训练代码只需要做下面的改造:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

#-------------升级后---------

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:

当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~

补充:Pytorch设置多线程进行dataloader时影响GPU运行

使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。

以载入cocodataset为例

DataLoader

dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
                                                     (config["img_w"], config["img_h"]),
                                                     is_training=True),
                                         batch_size=config["batch_size"],
                                         shuffle=True, num_workers=32, pin_memory=True)

numworkers就是指定多少线程的参数,原为32。

检查GPU是否运行该程序

查看运行在gpu上的所有程序:

fuser -v /dev/nvidia*

如果没有返回,则该程序并没有在GPU上运行

指定GPU运行

将num_workers改成0即可

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python调用staf自动化框架的方法

    python调用staf自动化框架的方法

    今天小编就为大家分享一篇python调用staf自动化框架的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • tensorflow 使用flags定义命令行参数的方法

    tensorflow 使用flags定义命令行参数的方法

    本篇文章主要介绍了tensorflow 使用flags定义命令行参数的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-04-04
  • 解决Jupyter 文件路径的问题

    解决Jupyter 文件路径的问题

    这篇文章主要介绍了解决Jupyter 文件路径的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • python 梯度法求解函数极值的实例

    python 梯度法求解函数极值的实例

    今天小编就为大家分享一篇python 梯度法求解函数极值的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • python中的下划线多种用法总结

    python中的下划线多种用法总结

    在 Python 中,下划线(underscore)有多种用法,它在不同的上下文中可以扮演不同的角色,本文将介绍python中的下划线用法总结,感兴趣的朋友一起看看吧
    2024-05-05
  • python读取和保存图片5种方法对比

    python读取和保存图片5种方法对比

    为大家分享一下python读取和保存图片5种方法与比较,python中对象之间的赋值是按引用传递的,如果需要拷贝对象,需要用到标准库中的copy模块
    2018-09-09
  • PyCharm+PyQt5+QtDesigner配置详解

    PyCharm+PyQt5+QtDesigner配置详解

    这篇文章主要介绍了PyCharm+PyQt5+QtDesigner配置详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-08-08
  • Python简单实现Base64编码和解码的方法

    Python简单实现Base64编码和解码的方法

    这篇文章主要介绍了Python简单实现Base64编码和解码的方法,结合具体实例形式分析了Python实现base64编码解码相关函数与使用技巧,需要的朋友可以参考下
    2017-04-04
  • 使用python提取PowerPoint幻灯片中表格并保存到文本及Excel文件

    使用python提取PowerPoint幻灯片中表格并保存到文本及Excel文件

    owerPoint作为广泛使用的演示工具,常被用于展示各类数据报告和分析结果,其中,表格以其直观性和结构性成为阐述数据关系的不二之选,本文将介绍如何使用Python来提取PowerPoint幻灯片中的表格,并将表格数据写入文本文件以及Excel文件,需要的朋友可以参考下
    2024-06-06
  • Python中的迭代器漫谈

    Python中的迭代器漫谈

    这篇文章主要介绍了Python中的迭代器漫谈,本文主要讲解range函数和xrange函数性能区别,需要的朋友可以参考下
    2015-02-02

最新评论