pytorch中retain_graph==True的作用说明

 更新时间:2023年02月21日 08:45:56   作者:撒旦即可  
这篇文章主要介绍了pytorch中retain_graph==True的作用说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch retain_graph==True的作用说明

总的来说进行一次backward之后,各个节点的值会清除,这样进行第二次backward会报错,如果加上retain_graph==True后,可以再来一次backward。 

retain_graph参数的作用

官方定义:

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

大意是如果设置为False,计算图中的中间变量在计算完后就会被释放。

但是在平时的使用中这个参数默认都为False从而提高效率,和creat_graph的值一样。

具体看一个例子理解

假设一个我们有一个输入x,y = x **2, z = y*4,然后我们有两个输出,一个output_1 = z.mean(),另一个output_2 = z.sum()。

然后我们对两个output执行backward。

import torch
x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
y = x ** 2
z = y * 4
print(x)
print(y)
print(z)
loss1 = z.mean()
loss2 = z.sum()
print(loss1,loss2)
loss1.backward()    # 这个代码执行正常,但是执行完中间变量都free了,所以下一个出现了问题
print(loss1,loss2)
loss2.backward()    # 这时会引发错误

程序正常执行到第12行,所有的变量正常保存。

但是在第13行报错:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

分析:计算节点数值保存了,但是计算图x-y-z结构被释放了,而计算loss2的backward仍然试图利用x-y-z的结构,因此会报错。

因此需要retain_graph参数为True去保留中间参数从而两个loss的backward()不会相互影响。

正确的代码应当把第11行以及之后改成

  • 1 # 假如你需要执行两次backward,先执行第一个的backward,再执行第二个backward
  • 2 loss1.backward(retain_graph=True)# 这里参数表明保留backward后的中间参数。
  • 3 loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
  • 4  #如果是在训练网络optimizer.step() # 更新参数

create_graph参数比较简单,参考官方定义:

create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.

Pytorch retain_graph=True错误信息

(Pytorch:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time)

具有多个loss值

retain_graph设置True,一般多用于两次backward

# 假如有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True) # 这样计算图就不会立即释放
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

retain_graph设置True后一定要知道释放,否则显卡会占用越来越多,代码速度也会跑的越来越慢。

有的时候我明明仅有一个模型的也会出现这种错误

第一种是输入的原因。

// Example
x = torch.randn((100,1), requires_grad = True)
y = 1 + 2 * x + 0.3 * torch.randn(100,1)
x_train, y_train = x[:70], y[:70]
x_val, y_val = x[70:], y[70:]

for epoch in range(n_epochs):
    ...
    prediction = model(x_train)
    loss.backward()
    ...

在多次循环的过程中,input的梯度没有清除,而且我们也不需要计算输入的梯度,因此将x的require_grad设置为False就可以解决问题。

第二种是我在训练LSTM时候发现的。

class LSTMpred(nn.Module):
    def __init__(self, input_size, hidden_dim):
        self.hidden = self.init_hidden()
       ...
    def init_hidden(self):    #这里我们是需要个隐层参数的
        return (torch.zeros(1, 1, self.hidden_dim, requires_grad=True),
                torch.zeros(1, 1, self.hidden_dim, requires_grad=True))
    def forward(self, seq):
        ...

这里面的self.hidden我们在每一次训练的时候都要重新初始化隐层参数:

for epoch in range(Epoch):
    ...
    model.hidden = model.init_hidden()
    modout = model(seq)
    ...

3. 我的看法

其实,想想这几种情况都是一回事,都是网络在反向传播中不允许多个backward(),也就是梯度下降反馈的时候,有多个循环过程中共用了同一个需要计算梯度的变量,在前一个循环清除梯度后,后面一个循环过程就会在这个变量上栽跟头(个人想法)。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • K-近邻算法的python实现代码分享

    K-近邻算法的python实现代码分享

    这篇文章主要介绍了K-近邻算法的python实现代码分享,具有一定借鉴价值,需要的朋友可以参考下。
    2017-12-12
  • python制作爬虫爬取京东商品评论教程

    python制作爬虫爬取京东商品评论教程

    本文是继前2篇Python爬虫系列文章的后续篇,给大家介绍的是如何使用Python爬取京东商品评论信息的方法,并根据数据绘制成各种统计图表,非常的细致,有需要的小伙伴可以参考下
    2016-12-12
  • Python库textract提取各种文档类型中文本数据

    Python库textract提取各种文档类型中文本数据

    Python的textract库是一个强大的工具,它可以从各种文档类型中提取文本数据,无论是PDF、Word文档、图片还是其他格式的文件,textract都可以轻松地将文本提取出来,本文将详细介绍textract的功能和用法,并提供丰富的示例代码来帮助大家深入了解
    2024-01-01
  • python zip文件 压缩

    python zip文件 压缩

    看了我前面的一系列文章,不知道你会不会觉得python是无所不能的,我现在就这感觉!如何用python进行文件压缩呢
    2008-12-12
  • 一文详解Python中的Map,Filter和Reduce函数

    一文详解Python中的Map,Filter和Reduce函数

    这篇文章主要介绍了一文详解Python中的Map,Filter和Reduce函数,本文重点介绍Python中的三个特殊函数Map,Filter和Reduce,以及如何使用它们进行代码编程
    2022-08-08
  • Python实现爬虫IP负载均衡和高可用集群的示例代码

    Python实现爬虫IP负载均衡和高可用集群的示例代码

    做大型爬虫项目经常遇到请求频率过高的问题,这里需要说的是使用爬虫IP可以提高抓取效率,本文主要介绍了Python实现爬虫IP负载均衡和高可用集群的示例代码,感兴趣的可以了解一下
    2023-12-12
  • python实现吃苹果小游戏

    python实现吃苹果小游戏

    这篇文章主要为大家详细介绍了python实现吃苹果小游戏,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-03-03
  • python3压缩和解压文件案例总结

    python3压缩和解压文件案例总结

    压缩和解压缩是日常常用的操作,不管是windows上图形界面的操作,还是linux上用命令来进行压缩解压缩,总的而言都还是比较方便的,本文通过案例代码讲解了python3压缩和解压文件的方法,
    2023-02-02
  • Python Pandas处理csv文件常用示例

    Python Pandas处理csv文件常用示例

    Pandas是一个非常强大的数据操作python包,支持各种数据格式,包括CSV文件,本文就来介绍一下Python Pandas处理csv文件常用示例,感兴趣的可以了解一下
    2023-12-12
  • python实现plt x轴坐标按1刻度显示

    python实现plt x轴坐标按1刻度显示

    这篇文章主要介绍了python实现plt x轴坐标按1刻度显示,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-07-07

最新评论