弄清Pytorch显存的分配机制

 更新时间:2020年12月10日 10:18:22   作者:颀周  
这篇文章主要介绍了Pytorch显存的分配机制的相关资料,帮助大家更好的理解和使用Pytorch,感兴趣的朋友可以了解下

  对于显存不充足的炼丹研究者来说,弄清楚Pytorch显存的分配机制是很有必要的。下面直接通过实验来推出Pytorch显存的分配过程。

  实验实验代码如下:

import torch 
from torch import cuda 

x = torch.zeros([3,1024,1024,256],requires_grad=True,device='cuda') 
print("1", cuda.memory_allocated()/1024**2) 
y = 5 * x 
print("2", cuda.memory_allocated()/1024**2) 
torch.mean(y).backward()   
print("3", cuda.memory_allocated()/1024**2)  
print(cuda.memory_summary())

输出如下:

  代码首先分配3GB的显存创建变量x,然后计算y,再用y进行反向传播。可以看到,创建x后与计算y后分别占显存3GB与6GB,这是合理的。另外,后面通过backward(),计算出x.grad,占存与x一致,所以最终一共占有显存9GB,这也是合理的。但是,输出显示了显存的峰值为12GB,这多出的3GB是怎么来的呢?首先画出计算图:

下面通过列表的形式来模拟Pytorch在运算时分配显存的过程:

  如上所示,由于需要保存反向传播以前所有前向传播的中间变量,所以有了12GB的峰值占存。

  我们可以不存储计算图中的非叶子结点,达到节省显存的目的,即可以把上面的代码中的y=5*x与mean(y)写成一步:

import torch 
from torch import cuda 

x = torch.zeros([3,1024,1024,256],requires_grad=True,device='cuda') 
print("1", cuda.memory_allocated()/1024**2)  
torch.mean(5*x).backward()   
print("2", cuda.memory_allocated()/1024**2)  
print(cuda.memory_summary())

 占显存量减少了3GB:

以上就是弄清Pytorch显存的分配机制的详细内容,更多关于Pytorch 显存分配的资料请关注脚本之家其它相关文章!

相关文章

  • Python pyecharts绘制柱状图

    Python pyecharts绘制柱状图

    这篇文章主要介绍了Python pyecharts绘制柱状图,文章介绍的柱状/条形图,通过柱形的高度/条形的宽度来表现数据的大小,感兴趣的小伙伴一起进入文章学习更详细内容吧
    2021-12-12
  • Python实现的各种常见分布算法示例

    Python实现的各种常见分布算法示例

    这篇文章主要介绍了Python实现的各种常见分布算法,结合实例形式总结分析了Python常见的各种分布算法相关实现技巧,包括二项分布、离散分布、泊松分布、正态分布、指数分布等算法实现方法,需要的朋友可以参考下
    2018-12-12
  • pytorch visdom安装开启及使用方法

    pytorch visdom安装开启及使用方法

    这篇文章主要介绍了pytorch visdom安装开启及使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-04-04
  • Python logging模块原理解析及应用

    Python logging模块原理解析及应用

    这篇文章主要介绍了Python logging模块原理解析及应用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-08-08
  • Linux 下 Python 实现按任意键退出的实现方法

    Linux 下 Python 实现按任意键退出的实现方法

    这篇文章主要介绍了Linux 下 Python 实现按任意键退出的实现方法的相关资料,本文介绍的非常详细,具有参考借鉴价值,需要的朋友可以参考下
    2016-09-09
  • Python中flask框架跨域问题的解决方法

    Python中flask框架跨域问题的解决方法

    本文主要介绍了Python中flask框架跨域问题的解决方法,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-08-08
  • Python实现Logger打印功能的方法详解

    Python实现Logger打印功能的方法详解

    最近工作中遇到了打印的需求,通过查找相关的资料发现Python中Logger可以很好的实现打印,所以下面这篇文章主要给大家介绍了关于Python如何实现Logger打印功能的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下。
    2017-09-09
  • Python对列表的操作知识点详解

    Python对列表的操作知识点详解

    在本篇文章里小编给大家整理了关于Python对列表的操作知识点总结以及实例代码运用,需要的朋友们跟着学习下。
    2019-08-08
  • 用Python写一个简易版弹球游戏

    用Python写一个简易版弹球游戏

    这篇文章主要介绍了用Python写一个简易版弹球游戏,文中有很多实用代码,对正在学习python的小伙伴们有很大的帮助.需要的朋友可以参考下
    2021-04-04
  • python3.6 如何将list存入txt后再读出list的方法

    python3.6 如何将list存入txt后再读出list的方法

    这篇文章主要介绍了python3.6 如何将list存入txt后再读出list的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07

最新评论