pytorch 运行一段时间后出现GPU OOM的问题

 更新时间:2021年06月01日 17:18:54   作者:ASR_THU  
这篇文章主要介绍了pytorch 运行一段时间后出现GPU OOM的问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch的dataloader会将数据传到GPU上,这个过程GPU的mem占用会逐渐增加,为了避免GPUmen被无用的数据占用,可以在每个step后用del删除一些变量,也可以使用torch.cuda.empty_cache()释放显存:

del targets, input_k, input_mask
torch.cuda.empty_cache()

这时能观察到GPU的显存一直在动态变化。

但是上述方式不是一个根本的解决方案,因为他受到峰值的影响很大。比如某个batch的数据量明显大于其他batch,可能模型处理该batch时显存会不够用,这也会导致OOM,虽然其他的batch都能顺利执行。

显存的占用跟这几个因素相关:

模型参数量

batch size

一个batch的数据 size

通常我们不希望改变模型参数量,所以只能通过动态调整batch-size,使得一个batch的数据 size不会导致显存OOM:

ilen = int(sorted_data[start][1]['input'][0]['shape'][0])
olen = int(sorted_data[start][1]['output'][0]['shape'][0])
# if ilen = 1000 and max_length_in = 800
# then b = batchsize / 2
# and max(1, .) avoids batchsize = 0
# 太长的句子会被动态改变bsz,单独成一个batch,否则padding的部分就太多了,数据量太大,OOM
factor = max(int(ilen / max_length_in), int(olen / max_length_out))
b = max(1, int(batch_size / (1 + factor)))
#b = batch_size
end = min(len(sorted_data), start + b)
minibatch.append(sorted_data[start:end])
if end == len(sorted_data):
    break
start = end

此外,如何选择一个合适的batchsize也是个很重要的问题,我们可以先对所有数据按照大小(长短)排好序(降序),不进行shuffle,按照64,32,16依次尝试bsz,如果模型在执行第一个batch的时候没出现OOM,那么以后一定也不会出现OOM(因为降序排列了数据,所以前面的batch的数据size最大)。

还有以下问题

pytorch increasing cuda memory OOM 问题

改了点model 的计算方式,然后就 OOM 了,调小了 batch_size,然后发现发现是模型每次迭代都会动态增长 CUDA MEMORY, 在排除了 python code 中的潜在内存溢出问题之后,基本可以把问题定在 pytorch 的图计算问题上了,说明每次迭代都重新生成了一张计算图,然后都保存着在,就 OOM 了。

参考

CUDA memory continuously increases when net(images) called in every iteration

Understanding graphs and state

说是会生成多个计算图:

loss = SomeLossFunction(out) + SomeLossFunction(out)

准备用 sum来避免多次生成计算图的问题:

loss = Variable(torch.sum(torch.cat([loss1, loss2], 0)))

然而,调着调着就好了,和报错前的 code 没太大差别。估计的原因是在pycharm 远程连接服务器的时候 code 的保存版本差异问题,这个也需要解决一下。

还有个多次迭代再计算梯度的问题,类似于 caffe中的iter_size,这个再仔细看看。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 如何使用pytorch实现LocallyConnected1D

    如何使用pytorch实现LocallyConnected1D

    由于LocallyConnected1D是Keras中的函数,为了用pytorch实现LocallyConnected1D并在960×33的数据集上进行训练和验证,本文分步骤给大家介绍如何使用pytorch实现LocallyConnected1D,感兴趣的朋友一起看看吧
    2023-09-09
  • 基于Python实现m3u8视频下载

    基于Python实现m3u8视频下载

    m3u8 是一种基于文本的媒体播放列表文件格式,通常用于指定流媒体播放器播放在线媒体流,本文将利用Python实现m3u8视频下载器,感兴趣的可以了解一下
    2023-05-05
  • 对python中的for循环和range内置函数详解

    对python中的for循环和range内置函数详解

    下面小编就为大家分享一篇对python中的for循环和range内置函数详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • Python使用回溯法子集树模板获取最长公共子序列(LCS)的方法

    Python使用回溯法子集树模板获取最长公共子序列(LCS)的方法

    这篇文章主要介绍了Python使用回溯法子集树模板获取最长公共子序列(LCS)的方法,简单描述了最长公共子序列问题并结合实例形式分析了Python基于回溯法子集树模板获取最长公共子序列的操作步骤与相关注意事项,需要的朋友可以参考下
    2017-09-09
  • python manage.py createsuperuser运行错误问题解决

    python manage.py createsuperuser运行错误问题解决

    这篇文章主要介绍了python manage.py createsuperuser运行错误,本文给大家分享错误复现及解决方案,感兴趣的朋友一起看看吧
    2023-10-10
  • Python中使用多进程来实现并行处理的方法小结

    Python中使用多进程来实现并行处理的方法小结

    本篇文章主要介绍了Python中使用多进程来实现并行处理的方法小结,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-08-08
  • Pycharm 如何一键加引号的方法步骤

    Pycharm 如何一键加引号的方法步骤

    这篇文章主要介绍了Pycharm 如何一键加引号的方法步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-02-02
  • Python实现一个简单的QQ截图

    Python实现一个简单的QQ截图

    大家好,本篇文章主要讲的是Python实现一个简单的QQ截图,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下的相关资料
    2022-02-02
  • 解决Python 遍历字典时删除元素报异常的问题

    解决Python 遍历字典时删除元素报异常的问题

    下面小编就为大家带来一篇解决Python 遍历字典时删除元素报异常的问题。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2016-09-09
  • Python文件路径名的操作方法

    Python文件路径名的操作方法

    对于文件路径名的操作在编程中是必不可少的,比如说,有时候要列举一个路径下的文件,那么首先就要获取一个路径,再就是路径名的一个拼接问题,通过字符串的拼接就可以得到一个路径名。这篇文章主要介绍了Python中文件路径名的操作,需要的朋友可以参考下
    2019-10-10

最新评论