浅谈pytorch中为什么要用 zero_grad() 将梯度清零

 更新时间:2021年05月31日 14:20:28   作者:小小鼠标0  
这篇文章主要介绍了pytorch中为什么要用 zero_grad() 将梯度清零的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch中为什么要用 zero_grad() 将梯度清零

调用backward()函数之前都要将梯度清零,因为如果梯度不清零,pytorch中会将上次计算的梯度和本次计算的梯度累加。

这样逻辑的好处是,当我们的硬件限制不能使用更大的bachsize时,使用多次计算较小的bachsize的梯度平均值来代替,更方便,坏处当然是每次都要清零梯度。

optimizer.zero_grad()
output = net(input)
loss = loss_f(output, target)
loss.backward()

补充:Pytorch 为什么每一轮batch需要设置optimizer.zero_grad

CSDN上有人写过原因,但是其实写得繁琐了。

根据pytorch中的backward()函数的计算,当网络参量进行反馈时,梯度是被积累的而不是被替换掉;但是在每一个batch时毫无疑问并不需要将两个batch的梯度混合起来累积,因此这里就需要每个batch设置一遍zero_grad 了。

其实这里还可以补充的一点是,如果不是每一个batch就清除掉原有的梯度,而是比如说两个batch再清除掉梯度,这是一种变相提高batch_size的方法,对于计算机硬件不行,但是batch_size可能需要设高的领域比较适合,比如目标检测模型的训练。

关于这一点可以参考这里

关于backward()的计算可以参考这里

补充:pytorch 踩坑笔记之w.grad.data.zero_()

在使用pytorch实现多项线性回归中,在grad更新时,每一次运算后都需要将上一次的梯度记录清空,运用如下方法:

w.grad.data.zero_()
b.grad.data.zero_() 

但是,运行程序就会报如下错误:

报错,grad没有data这个属性,

原因是,在系统将w的grad值初始化为none,第一次求梯度计算是在none值上进行报错,自然会没有data属性

修改方法:添加一个判断语句,从第二次循环开始执行求导运算

for i in range(100):
    y_pred = multi_linear(x_train)
    loss = getloss(y_pred,y_train)
    if i != 0:
        w.grad.data.zero_()
        b.grad.data.zero_()
    loss.backward()
    w.data = w.data - 0.001 * w.grad.data
    b.data = b.data - 0.001 * b.grad.data

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python Django 添加首页尾页上一页下一页代码实例

    Python Django 添加首页尾页上一页下一页代码实例

    这篇文章主要介绍了Python Django 添加首页尾页上一页下一页代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • Python中Generators教程的实现

    Python中Generators教程的实现

    本文主要介绍了Python中Generators教程的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • Python multiprocessing多进程原理与应用示例

    Python multiprocessing多进程原理与应用示例

    这篇文章主要介绍了Python multiprocessing多进程原理与应用,结合实例形式详细分析了基于multiprocessing包的多进程概念、原理及相关使用操作技巧,需要的朋友可以参考下
    2019-02-02
  • python 如何做一个识别率百分百的OCR

    python 如何做一个识别率百分百的OCR

    最近在做游戏自动化(测试),也就是游戏脚本了。主要有以下几个需求识别率百分百、速度要快、模型要小,本文就来着手实现它
    2021-05-05
  • python小练习之爬鱿鱼游戏的评价生成词云

    python小练习之爬鱿鱼游戏的评价生成词云

    读万卷书不如行万里路,只学书上的理论是远远不够的,只有在实战中才能获得能力的提升,本篇文章手把手带你用Python爬取热火的鱿鱼游戏评价,大家可以在过程中查缺补漏,提升水平
    2021-10-10
  • PyQt实现异步数据库请求的实战记录

    PyQt实现异步数据库请求的实战记录

    开发软件的时候不可避免要和数据库发生交互,但是有些 SQL 请求非常耗时,如果在主线程中发送请求,可能会造成界面卡顿,本文将介绍一种让数据库请求变得和前端的 ajax 请求一样简单,希望对大家有所帮助
    2023-12-12
  • python处理emoji表情(两个函数解决两者之间的联系)

    python处理emoji表情(两个函数解决两者之间的联系)

    这篇文章主要介绍了python处理emoji表情,主要通过两个函数解决两者之间的联系,本文通过实例代码给大家介绍的非常完美,对python emoji表情的相关知识感兴趣的朋友一起看看吧
    2021-05-05
  • Python中用xlwt制作表格实例讲解

    Python中用xlwt制作表格实例讲解

    在本篇文章里小编给大家整理的是一篇关于Python中用xlwt制作表格实例讲解内容,有兴趣的朋友们可以学习下。
    2020-11-11
  • Python教程之全局变量用法

    Python教程之全局变量用法

    这篇文章主要介绍了Python教程之全局变量用法,结合实例形式分析了Python全局变量的定义、修改等使用方法及注意事项,需要的朋友可以参考下
    2016-06-06
  • 如何在Python中隐藏和加密密码示例详解

    如何在Python中隐藏和加密密码示例详解

    Maskpass是一个类似getpass的Python库,但是具有一些高级功能,比如掩蔽和显示,下面这篇文章主要给大家介绍了关于如何在Python中隐藏和加密密码的相关资料,需要的朋友可以参考下
    2022-02-02

最新评论