Pytorch中的modle.train,model.eval,with torch.no_grad解读

 更新时间:2022年12月14日 15:30:41   作者:l8947943  
这篇文章主要介绍了Pytorch中的modle.train,model.eval,with torch.no_grad解读,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

modle.train,model.eval,with torch.no_grad解读

1. 最近在学习pytorch过程中遇到了几个问题

不理解为什么在训练和测试函数中model.eval(),和model.train()的区别,经查阅后做如下整理

一般情况下,我们训练过程如下:

拿到数据后进行训练,在训练过程中,使用

  • model.train():告诉我们的网络,这个阶段是用来训练的,可以更新参数。

训练完成后进行预测,在预测过程中,使用

  • model.eval(): 告诉我们的网络,这个阶段是用来测试的,于是模型的参数在该阶段不进行更新。

2. 但是为什么在eval()阶段会使用with torch.no_grad()?

查阅相关资料:传送门

with torch.no_grad - disables tracking of gradients in autograd.
model.eval() changes the forward() behaviour of the module it is called upon
       eg, it disables dropout and has batch norm use the entire population statistics

总结一下就是说,在eval阶段了,即使不更新,但是在模型中所使用的dropout或者batch norm也就失效了,直接都会进行预测,而使用no_grad则设置让梯度Autograd设置为False(因为在训练中我们默认是True),这样保证了反向过程为纯粹的测试,而不变参数。

另外,参考文档说这样避免每一个参数都要设置,解放了GPU底层的时间开销,在测试阶段统一梯度设置为False

model.eval()与torch.no_grad()的作用

model.eval()

经常在模型推理代码的前面, 都会添加model.eval(), 主要有3个作用:

  • 1.不进行dropout
  • 2.不更新batchnorm的mean 和var 参数
  • 3.不进行梯度反向传播, 但梯度仍然会计算

torch.no_grad()

torch.no_grad的一般使用方法是, 在代码块外面用with torch.no_grad()给包起来。 如下面这样:

with torch.no_grad():
    # your code 

它的主要作用有2个:

  • 1.不进行梯度的计算(当然也就没办法反向传播了), 节约显存和算力
  • 2.dropout和batchnorn还是会正常更新

异同

从上面的介绍中可以非常明确的看出,它们的相同点是一般都用在推理阶段, 但它们的作用是完全不同的, 也没有重叠。 可以一起使用。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • pytorch进行上采样的种类实例

    pytorch进行上采样的种类实例

    今天小编就为大家分享一篇pytorch进行上采样的种类实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • matplotlib实现区域颜色填充

    matplotlib实现区域颜色填充

    这篇文章主要为大家详细介绍了matplotlib实现区域颜色填充,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-03-03
  • python中关于xmltodict的使用

    python中关于xmltodict的使用

    这篇文章主要介绍了python中关于xmltodict的使用,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-03-03
  • Python Pandas处理csv文件常用示例

    Python Pandas处理csv文件常用示例

    Pandas是一个非常强大的数据操作python包,支持各种数据格式,包括CSV文件,本文就来介绍一下Python Pandas处理csv文件常用示例,感兴趣的可以了解一下
    2023-12-12
  • Python ''takes exactly 1 argument (2 given)'' Python error

    Python ''takes exactly 1 argument (2 given)'' Python error

    这篇文章主要介绍了Python 'takes exactly 1 argument (2 given)' Python error的相关资料,需要的朋友可以参考下
    2016-12-12
  • 深入了解Python中字符串格式化工具f-strings的使用

    深入了解Python中字符串格式化工具f-strings的使用

    从Python 3.6版本开始,引入了一种新的字符串格式化机制,即f-strings,它强大且易于使用的字符串格式化方式,本文就来聊聊他的具体使用,希望对大家有所帮助
    2023-05-05
  • 使用python绘制随机地形地图

    使用python绘制随机地形地图

    Python 作为一门功能强大的编程语言,在地图生成方面有着丰富的资源和库,本文将介绍如何使用 Python 中的一些工具和库来绘制随机地形地图,感兴趣的小伙伴可以跟着小编一起来看看
    2024-04-04
  • Pyecharts 中Geo函数常用参数的用法说明

    Pyecharts 中Geo函数常用参数的用法说明

    这篇文章主要介绍了Pyecharts 中Geo函数常用参数的用法说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-02-02
  • 基于Python实现流星雨效果的绘制

    基于Python实现流星雨效果的绘制

    这篇文章主要为大家介绍了如何利用Python绘制一个浪漫的流星雨效果,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起动手试一试
    2022-03-03
  • Django 模板中常用的过滤器实现

    Django 模板中常用的过滤器实现

    在模版中,有时候需要对一些数据进行处理以后才能使用。一般在Python中我们是通过函数的形式来完成的。而在模版中,则是通过过滤器来实现的,本文就来介绍一下如何实现
    2021-05-05

最新评论