Pytorch中retain_graph的坑及解决

 更新时间:2023年02月21日 09:08:40   作者:Longlongaaago  
这篇文章主要介绍了Pytorch中retain_graph的坑及解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

Pytorch中retain_graph的坑

在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用就是

在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_img = Variable(target)
        if torch.cuda.is_available():
            real_img = real_img.cuda()
        z = Variable(data)
        if torch.cuda.is_available():
            z = z.cuda()
        fake_img = netG(z)
 
        netD.zero_grad()
        real_out = netD(real_img).mean()
        fake_out = netD(fake_img).mean()
        d_loss = 1 - real_out + fake_out
        d_loss.backward(retain_graph=True) #####
        optimizerD.step()
 
        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        netG.zero_grad()
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()
        optimizerG.step()
        fake_img = netG(z)
        fake_out = netD(fake_img).mean()
 
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        running_results['g_loss'] += g_loss.data[0] * batch_size
        d_loss = 1 - real_out + fake_out
        running_results['d_loss'] += d_loss.data[0] * batch_size
        running_results['d_score'] += real_out.data[0] * batch_size
        running_results['g_score'] += fake_out.data[0] * batch_size

也就是说,只要我们有一个loss,我们就可以先loss.backward(retain_graph=True)  让它先计算梯度,若下面还有其他损失,但是可能你想扩展代码,可能有些loss是不用的,所以先加了 if 等判别语句进行了干预,使用loss.backward(retain_graph=True)就可以单独的计算梯度,屡试不爽。

但是另外一个问题在于,如果你都这么用的话,显存会爆炸,因为他保留了梯度,所以都没有及时释放掉,浪费资源。

而正确的做法应该是,在你最后一个loss 后面,一定要加上loss.backward()这样的形式,也就是让最后一个loss 释放掉之前所有暂时保存下来得梯度!!

Pytorch中有多次backward时需要retain_graph参数

Pytorch中的机制是每次调用loss.backward()时都会free掉计算图中所有缓存的buffers,当模型中可能有多次backward()时,因为前一次调用backward()时已经释放掉了buffer,所以下一次调用时会因为buffers不存在而报错

解决办法

loss.backward(retain_graph=True)

错误使用

  • optimizer.zero_grad() 清空过往梯度;
  • loss1.backward(retain_graph=True) 反向传播,计算当前梯度;
  • loss2.backward(retain_graph=True) 反向传播,计算当前梯度;
  • optimizer.step() 根据梯度更新网络参数

因为每次调用bckward时都没有将buffers释放掉,所以会导致内存溢出,迭代越来越慢(因为梯度都保存了,没有free)

正确使用

  • optimizer.zero_grad() 清空过往梯度;
  • loss1.backward(retain_graph=True) 反向传播,计算当前梯度;
  • loss2.backward() 反向传播,计算当前梯度;
  • optimizer.step() 根据梯度更新网络参数

最后一个 backward() 不要加 retain_graph 参数,这样每次更新完成后会释放占用的内存,也就不会出现越来越慢的情况了

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python调用ffmpeg开源视频处理库,批量处理视频

    Python调用ffmpeg开源视频处理库,批量处理视频

    本文主要介绍了如何用Python调用ffmpeg开源视频处理库,来实现视频批量的处理:水印、背景音乐、剪辑、合并、帧率、速率、分辨率等操作
    2020-11-11
  • Python中字符串转换为列表的常用方法总结

    Python中字符串转换为列表的常用方法总结

    本文将详细介绍Python中将字符串转换为列表的八种常用方法,每种方法都具有其独特的用途和适用场景,文中的示例代码讲解详细,感兴趣的可以了解下
    2023-11-11
  • Python中使用urllib2模块编写爬虫的简单上手示例

    Python中使用urllib2模块编写爬虫的简单上手示例

    这篇文章主要介绍了Python中使用urllib2模块编写爬虫的简单上手示例,文中还介绍到了相关异常处理功能的添加,需要的朋友可以参考下
    2016-01-01
  • Python eval()与exec()函数使用介绍

    Python eval()与exec()函数使用介绍

    exec函数执行的是python语句,没有返回值,eval函数执行的是python表达式,有返回值,exec函数和eval函数都可以传入命名空间作为参数,本文给大家介绍下Python eval()和exec()函数,感兴趣的朋友跟随小编一起看看吧
    2023-01-01
  • Python 数字转化成列表详情

    Python 数字转化成列表详情

    这篇文章主要介绍了Python 数字转化成列表,主要以代码实现了将输入的数字转化成一个列表,输入数字中的每一位按照从左到右的顺序成为列表中的一项。,需要的朋友可以参考下
    2021-11-11
  • 解决Django cors跨域问题

    解决Django cors跨域问题

    这篇文章主要介绍了解决Django cors跨域问题,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-06-06
  • Pytorch中view函数实例讲解

    Pytorch中view函数实例讲解

    这篇文章主要给大家介绍了关于Pytorch中view函数的相关资料,PyTorch中的.view()函数是一个用于改变张量形状的方法,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2023-09-09
  • python 重命名轴索引的方法

    python 重命名轴索引的方法

    今天小编就为大家分享一篇python 重命名轴索引的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • Python利用Tiler制作专属卡通头像和LOGO

    Python利用Tiler制作专属卡通头像和LOGO

    Tiler是一种使用各种其他较小图像平铺创建新图像的工具,它与其他马赛克工具不同,因为它可以适应多种形状、大小、方向的贴图,称为buil in build。本文就来利用Tiler制作专属卡通头像和LOGO,需要的可以参考一下
    2022-12-12
  • Python 存取npy格式数据实例

    Python 存取npy格式数据实例

    这篇文章主要介绍了Python 存取npy格式数据实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-07-07

最新评论