批标准化层 tf.keras.layers.Batchnormalization()解析

 更新时间:2023年02月21日 16:34:03   作者:壮壮不太胖^QwQ  
这篇文章主要介绍了批标准化层 tf.keras.layers.Batchnormalization(),具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

批标准化层 tf.keras.layers.Batchnormalization()

tf.keras.layers.Batchnormalization()

重要参数:

  • training:布尔值,指示图层应在训练模式还是在推理模式下运行。
  • training=True:该图层将使用当前批输入的均值和方差对其输入进行标准化。
  • training=False:该层将使用在训练期间学习的移动统计数据的均值和方差来标准化其输入。

BatchNormalization 广泛用于 Keras 内置的许多高级卷积神经网络架构,比如 ResNet50、Inception V3 和 Xception。

BatchNormalization 层通常在卷积层或密集连接层之后使用。

批标准化的实现过程

  • 求每一个训练批次数据的均值
  • 求每一个训练批次数据的方差
  • 数据进行标准化
  • 训练参数γ,β
  • 输出y通过γ与β的线性变换得到原来的数值

在训练的正向传播中,不会改变当前输出,只记录下γ与β。在反向传播的时候,根据求得的γ与β通过链式求导方式,求出学习速率以至改变权值。

对于预测阶段时所使用的均值和方差,其实也是来源于训练集。比如我们在模型训练时我们就记录下每个batch下的均值和方差,待训练完毕后,我们求整个训练样本的均值和方差期望值,作为我们进行预测时进行BN的的均值和方差。

批标准化的使用位置

原始论文讲在CNN中一般应作用与非线性激活函数之前,但是,实际上放在激活函数之后效果可能会更好。

# 放在非线性激活函数之前
model.add(tf.keras.layers.Conv2D(64, (3, 3)))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Activation('relu'))

# 放在激活函数之后
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())

tf.keras.layers.BatchNormalization使用细节

关于keras中的BatchNormalization使用,官方文档说的足够详细。本文的目的旨在说明在BatchNormalization的使用过程中容易被忽略的细节。

在BatchNormalization的Arguments参数中有trainable属性;以及在Call arguments参数中有training。两个都是bool类型。第一次看到有两个参数的时候,我有点懵,为什么需要两个?

后来在查阅资料后发现了两者的不同作用。

1,trainable是Argument参数,类似于c++中构造函数的参数一样,是构建一个BatchNormalization层时就需要传入的,至于它的作用在下面会讲到。

2,training参数时Call argument(调用参数),是运行过程中需要传入的,用来控制模型在那个模式(train还是interfere)下运行。关于这个参数,如果使用模型调用fit()的话,是可以不给的(官方推荐是不给),因为在fit()的时候,模型会自己根据相应的阶段(是train阶段还是inference阶段)决定training值,这是由learning——phase机制实现的。

重点

关于trainable=False:如果设置trainable=False,那么这一层的BatchNormalization层就会被冻结(freeze),它的trainable weights(可训练参数)(就是gamma和beta)就不会被更新。

注意:freeze mode和inference mode是两个概念。

但是,在BatchNormalization层中,如果把某一层BatchNormalization层设置为trainable=False,那么这一层BatchNormalization层将一inference mode运行,也就是说(meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python如何把嵌套列表转变成普通列表

    python如何把嵌套列表转变成普通列表

    这篇文章主要为大家详细介绍了python如何把嵌套列表转变成普通列表,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03
  • 基于Python列表解析(列表推导式)

    基于Python列表解析(列表推导式)

    今天小编就为大家分享一篇基于Python列表解析(列表推导式),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • Python中sys模块常用方法与变量实例探究

    Python中sys模块常用方法与变量实例探究

    sys 模块是 Python 标准库中的一个核心模块,提供了与解释器进行交互的功能,了解 sys 模块的方法和变量对于更有效地管理和调试 Python 程序至关重要,本文将深入探讨 sys 模块的常用方法和变量,通过详细的示例代码,帮助大家更全面地了解并灵活运用这一关键模块
    2024-01-01
  • Python 拷贝对象(深拷贝deepcopy与浅拷贝copy)

    Python 拷贝对象(深拷贝deepcopy与浅拷贝copy)

    Python中的对象之间赋值时是按引用传递的,如果需要拷贝对象,需要使用标准库中的copy模块。
    2008-09-09
  • python2和python3应该学哪个(python3.6与python3.7的选择)

    python2和python3应该学哪个(python3.6与python3.7的选择)

    许多刚入门 Python 的朋友都在纠结的的问题是:我应该选择学习 python2 还是 python3,Python 3.7 已经发布了,目前Python的用户,主要使用的版本 应该是 Python3.6 和 Python2.7 ,那么是不是该转到 Python 3.7 呢
    2019-10-10
  • Python必备技巧之集合Set的使用

    Python必备技巧之集合Set的使用

    在数学中,对集合的严格定义可能是抽象的且难以掌握。但实际上可以将集合简单地认为是定义明确的不同对象的集合,通常称为元素或成员。Python 提供了一个内置的集合类型来将对象分组到一个集合中,快跟随小编一起学习一下吧
    2022-03-03
  • 基于python实现上传文件到OSS代码实例

    基于python实现上传文件到OSS代码实例

    这篇文章主要介绍了基于python实现上传文件到OSS,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • Matlab中如何实现将长字符串换行写

    Matlab中如何实现将长字符串换行写

    这篇文章主要介绍了Matlab中如何实现将长字符串换行写问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-01-01
  • Pytorch如何指定device(cuda or cpu)

    Pytorch如何指定device(cuda or cpu)

    这篇文章主要介绍了Pytorch如何指定device(cuda or cpu)问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-06-06
  • Python语言异常处理测试过程解析

    Python语言异常处理测试过程解析

    这篇文章主要介绍了Python语言异常处理测试过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-01-01

最新评论