Pytorch 如何实现常用正则化

 更新时间:2021年05月27日 10:35:32   作者:winycg  
这篇文章主要介绍了Pytorch 实现常用正则化的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

在这里插入图片描述

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python登录并获取CSDN博客所有文章列表代码实例

    Python登录并获取CSDN博客所有文章列表代码实例

    这篇文章主要介绍了Python登录并获取CSDN博客所有文章列表代码实例,具有一定借鉴价值,需要的朋友可以参考下
    2017-12-12
  • Python解析nginx日志文件

    Python解析nginx日志文件

    Web服务器的各种系统管理工作包括了多Nginx/Apache 日志的统计,python使这个任务变得极其简单,下面我们来详细讲解下具体的做法,有需要的小伙伴可以参考下。
    2015-05-05
  • 关于python简单的爬虫操作(requests和etree)

    关于python简单的爬虫操作(requests和etree)

    这篇文章主要介绍了关于python简单的爬虫操作(requests和etree),文中提供了实现代码,需要的朋友可以参考下
    2023-04-04
  • python写日志文件操作类与应用示例

    python写日志文件操作类与应用示例

    这篇文章主要介绍了python写日志文件操作类与应用,结合实例形式分析了Python日志文件操作类的定义与使用相关操作技巧,需要的朋友可以参考下
    2019-07-07
  • python使用in操作符时元组和数组的区别分析

    python使用in操作符时元组和数组的区别分析

    有时候要判断一个数是否在一个序列里面,这时就会用到in运算符来判断成员资格,如果条件为真时,就会返回true,条件为假时,返回一个flase。这样的运算符叫做布尔运算符,其真值叫做布尔值。
    2015-05-05
  • Python实现扩展内置类型的方法分析

    Python实现扩展内置类型的方法分析

    这篇文章主要介绍了Python实现扩展内置类型的方法,结合实例形式分析了Python嵌入内置类型扩展及子类方式扩展的具体实现技巧,需要的朋友可以参考下
    2017-10-10
  • django 读取图片到页面实例

    django 读取图片到页面实例

    这篇文章主要介绍了django 读取图片到页面实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-03-03
  • python pycharm的安装及其使用

    python pycharm的安装及其使用

    这篇文章主要介绍了python pycharm的安装及其使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-10-10
  • Python 读写文件的操作代码

    Python 读写文件的操作代码

    本文通过实例代码给大家介绍了Python 读写文件的操作方法,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2018-09-09
  • Python调用MySQLdb插入中文乱码的解决

    Python调用MySQLdb插入中文乱码的解决

    这篇文章主要介绍了Python调用MySQLdb插入中文乱码的解决,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-03-03

最新评论