详解model.train()和model.eval()两种模式的原理与用法

 更新时间:2023年03月23日 17:01:13   作者:想变厉害的大白菜  
这篇文章主要介绍了详解model.train()和model.eval()两种模式的原理与用法,相信很多没有经验的人对此束手无策,那么看完这篇文章一定会对你有所帮助

一、两种模式

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval()。

一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。

二、功能

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval()的作用是 不启用 Batch Normalization 和 Dropout。

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

为什么测试时要用 model.eval() ?

训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。

eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。

也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout。

3. 总结与对比

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。

其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。

三、Dropout 简介

dropout 常常用于抑制过拟合。

设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。

到此这篇关于详解model.train()和model.eval()两种模式的原理与用法的文章就介绍到这了,更多相关model.train()和model.eval()原理用法内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python实现单项链表的最全教程

    Python实现单项链表的最全教程

    单向链表也叫单链表,是链表中最简单的一种形式,它的每个节点包含两个域,一个信息域(元素域)和一个链接域,这个链接指向链表中的下一个节点,而最后一个节点的链接域则指向一个空值,这篇文章主要介绍了Python实现单项链表,需要的朋友可以参考下
    2023-01-01
  • python环境配置方式(服务器+本地)

    python环境配置方式(服务器+本地)

    这篇文章详细介绍了在服务器上安装和配置Anaconda3、TensorFlow、PyTorch等深度学习环境的步骤,包括下载、初始化、创建环境、验证安装以及解决一些常见问题
    2025-01-01
  • Python模块文件结构代码详解

    Python模块文件结构代码详解

    这篇文章主要介绍了Python模块文件结构代码详解,分享了相关代码示例,小编觉得还是挺不错的,具有一定借鉴价值,需要的朋友可以参考下
    2018-02-02
  • Python/MySQL实现Excel文件自动处理数据功能

    Python/MySQL实现Excel文件自动处理数据功能

    在没有服务器存储数据,只有excel文件的情况下,如何利用SQL和python实现数据分析和数据自动处理的功能?本文就来和大家聊聊解决办法
    2023-02-02
  • 解决python -m pip install --upgrade pip 升级不成功问题

    解决python -m pip install --upgrade pip 升级不成功问题

    这篇文章主要介绍了python -m pip install --upgrade pip 解决升级不成功问题,需要的朋友可以参考下
    2020-03-03
  • Python画笔的属性及用法详解

    Python画笔的属性及用法详解

    在本篇文章里小编给大家分享的是一篇关于Python画笔的属性及用法内容,有需要的朋友们可以学习下。
    2021-03-03
  • 解决使用Spyder IDE时matplotlib绘图的显示问题

    解决使用Spyder IDE时matplotlib绘图的显示问题

    这篇文章主要介绍了解决使用Spyder IDE时matplotlib绘图的显示问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-04-04
  • 深入解析Python中filter函数的使用

    深入解析Python中filter函数的使用

    在Python中,filter函数是一种内置的高阶函数,它能够接受一个函数和一个迭代器,然后返回一个新的迭代器,本文主要来介绍一下Python中filter函数的具体用法,需要的可以参考一下
    2023-07-07
  • 彻底搞懂 python 中文乱码问题(深入分析)

    彻底搞懂 python 中文乱码问题(深入分析)

    现在有的小伙伴为了躲避中文乱码的问题甚至代码中不使用中文,注释和提示都用英文,我曾经也这样干过,但这并不是解决问题,而是逃避问题,今天我们一起彻底解决 Python 中文乱码的问题
    2020-02-02
  • python实现文本界面网络聊天室

    python实现文本界面网络聊天室

    这篇文章主要为大家详细介绍了python实现文本界面网络聊天室,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-12-12

最新评论