pytorch 实现在测试的时候启用dropout

 更新时间:2021年05月27日 10:07:17   作者:qian99  
这篇文章主要介绍了pytorch 实现在测试的时候启用dropout的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?

在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:

def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()

下面是完整demo代码:

# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(8, 8)
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.fc(x)
        x = self.dropout(x)
        return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)

运行结果:

可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

补充:Pytorch之dropout避免过拟合测试

一.做数据

在这里插入图片描述

二.搭建神经网络

三.训练

在这里插入图片描述

四.对比测试结果

注意:测试过程中,一定要注意模式切换

在这里插入图片描述

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python如何将txt坐标批量打印到原图上

    python如何将txt坐标批量打印到原图上

    这篇文章主要介绍了python如何将txt坐标批量打印到原图上的问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-08-08
  • Python matplotlib修改默认字体的操作

    Python matplotlib修改默认字体的操作

    这篇文章主要介绍了Python matplotlib修改默认字体的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • 使用Keras构造简单的CNN网络实例

    使用Keras构造简单的CNN网络实例

    这篇文章主要介绍了使用Keras构造简单的CNN网络实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • 新手入门Python编程的8个实用建议

    新手入门Python编程的8个实用建议

    这篇文章主要介绍了Python编程的8个实用建议,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • CPython 垃圾收集器检测循环引用详解

    CPython 垃圾收集器检测循环引用详解

    这篇文章主要为大家介绍了CPython 垃圾收集器检测循环引用详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-10-10
  • Python学习之12个常用基础语法详解

    Python学习之12个常用基础语法详解

    这篇文章主要为大家介绍了12个Python小案例,包含了日常开发中非常实用的语法,快来跟随小编一起学习一下,看看自己都会多少个呢
    2022-02-02
  • PyTorch中torch.load()的用法和应用

    PyTorch中torch.load()的用法和应用

    torch.load()它用于加载由torch.save()保存的模型或张量,本文主要介绍了PyTorch中torch.load()的用法和应用,具有一定的参考价值,感兴趣的可以了解一下
    2024-03-03
  • 谈一谈数组拼接tf.concat()和np.concatenate()的区别

    谈一谈数组拼接tf.concat()和np.concatenate()的区别

    今天小编就为大家分享一篇谈谈数组拼接tf.concat()和np.concatenate()的区别,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • Python灰度变换中伽马变换分析实现

    Python灰度变换中伽马变换分析实现

    灰度变换是指根据某种目标条件按一定变换关系逐点改变源图像中每个像素灰度值的方法。目的是改善画质,使图像显示效果更加清晰。图像的灰度变换处理是图像增强处理技术中的一种非常基础、直接的空间域图像处理方法,也是图像数字化软件和图像显示软件的一个重要组成部分
    2022-10-10
  • 从零开始理解如何使用Python搭建智能AI代理

    从零开始理解如何使用Python搭建智能AI代理

    Agentic AI(智能代理)正在悄然改变我们的工作方式,所以这篇文章小编就来和大家简单介绍一下如何使用Python搭建智能AI代理,感兴趣的小伙伴可以了解下
    2025-07-07

最新评论