解决Pytorch中Batch Normalization layer踩过的坑

 更新时间:2021年05月27日 09:48:58   作者:机器AI  
这篇文章主要介绍了解决Pytorch中Batch Normalization layer踩过的坑,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

1. 注意momentum的定义

Pytorch中的BN层的动量平滑和常见的动量法计算方式是相反的,默认的momentum=0.1

BN层里的表达式为:

其中γ和β是可以学习的参数。在Pytorch中,BN层的类的参数有:

CLASS torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

每个参数具体含义参见文档,需要注意的是,affine定义了BN层的参数γ和β是否是可学习的(不可学习默认是常数1和0).

2. 注意BN层中含有统计数据数值,即均值和方差

track_running_stats – a boolean value that when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: True

在训练过程中model.train(),train过程的BN的统计数值—均值和方差是通过当前batch数据估计的。

并且测试时,model.eval()后,若track_running_stats=True,模型此刻所使用的统计数据是Running status 中的,即通过指数衰减规则,积累到当前的数值。否则依然使用基于当前batch数据的估计值。

3. BN层的统计数据更新

是在每一次训练阶段model.train()后的forward()方法中自动实现的,而不是在梯度计算与反向传播中更新optim.step()中完成

4. 冻结BN及其统计数据

从上面的分析可以看出来,正确的冻结BN的方式是在模型训练时,把BN单独挑出来,重新设置其状态为eval (在model.train()之后覆盖training状态).

解决方案:

You should use apply instead of searching its children, while named_children() doesn't iteratively search submodules.

def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
      m.eval()
model.apply(set_bn_eval)

或者,重写module中的train()方法:

def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        """
        super(MyNet, self).train(mode)
        if self.freeze_bn:
            print("Freezing Mean/Var of BatchNorm2D.")
            if self.freeze_bn_affine:
                print("Freezing Weight/Bias of BatchNorm2D.")
        if self.freeze_bn:
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if self.freeze_bn_affine:
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False

5. Fix/frozen Batch Norm when training may lead to RuntimeError: expected scalar type Half but found Float

解决办法:

import torch
import torch.nn as nn
from torch.nn import init
from torchvision import models
from torch.autograd import Variable
from apex.fp16_utils import *
def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
model = models.resnet50(pretrained=True)
model.cuda()
model = network_to_half(model)
model.train()
model.apply(fix_bn) # fix batchnorm
input = Variable(torch.FloatTensor(8, 3, 224, 224).cuda().half())
output = model(input)
output_mean = torch.mean(output)
output_mean.backward()

Please do

def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval().half()

Reason for this is, for regular training it is better (performance-wise) to use cudnn batch norm, which requires its weights to be in fp32, thus batch norm modules are not converted to half in network_to_half. However, cudnn does not support batchnorm backward in the eval mode , which is what you are doing, and to use pytorch implementation for this, weights have to be of the same type as inputs.

补充:深度学习总结:用pytorch做dropout和Batch Normalization时需要注意的地方,用tensorflow做dropout和BN时需要注意的地方

用pytorch做dropout和BN时需要注意的地方

pytorch做dropout:

就是train的时候使用dropout,训练的时候不使用dropout,

pytorch里面是通过net.eval()固定整个网络参数,包括不会更新一些前向的参数,没有dropout,BN参数固定,理论上对所有的validation set都要使用net.eval()

net.train()表示会纳入梯度的计算。

net_dropped = torch.nn.Sequential(
    torch.nn.Linear(1, N_HIDDEN),
    torch.nn.Dropout(0.5),  # drop 50% of the neuron
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDEN, N_HIDDEN),
    torch.nn.Dropout(0.5),  # drop 50% of the neuron
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDEN, 1),
)
for t in range(500):
    pred_drop = net_dropped(x)
    loss_drop = loss_func(pred_drop, y)
    optimizer_drop.zero_grad()
    loss_drop.backward()
    optimizer_drop.step()
    if t % 10 == 0:
        # change to eval mode in order to fix drop out effect
        net_dropped.eval()  # parameters for dropout differ from train mode
        test_pred_drop = net_dropped(test_x)
        # change back to train mode
        net_dropped.train()

pytorch做Batch Normalization:

net.eval()固定整个网络参数,固定BN的参数,moving_mean 和moving_var,不懂这个看下图:

            if self.do_bn:
                bn = nn.BatchNorm1d(10, momentum=0.5)
                setattr(self, 'bn%i' % i, bn)   # IMPORTANT set layer to the Module
                self.bns.append(bn)
    for epoch in range(EPOCH):
        print('Epoch: ', epoch)
        for net, l in zip(nets, losses):
            net.eval()              # set eval mode to fix moving_mean and moving_var
            pred, layer_input, pre_act = net(test_x)
            net.train()             # free moving_mean and moving_var
        plot_histogram(*layer_inputs, *pre_acts)  

moving_mean 和moving_var

在这里插入图片描述

用tensorflow做dropout和BN时需要注意的地方

dropout和BN都有一个training的参数表明到底是train还是test, 表明test那dropout就是不dropout,BN就是固定住了BN的参数;

tf_is_training = tf.placeholder(tf.bool, None)  # to control dropout when training and testing
# dropout net
d1 = tf.layers.dense(tf_x, N_HIDDEN, tf.nn.relu)
d1 = tf.layers.dropout(d1, rate=0.5, training=tf_is_training)   # drop out 50% of inputs
d2 = tf.layers.dense(d1, N_HIDDEN, tf.nn.relu)
d2 = tf.layers.dropout(d2, rate=0.5, training=tf_is_training)   # drop out 50% of inputs
d_out = tf.layers.dense(d2, 1)
for t in range(500):
    sess.run([o_train, d_train], {tf_x: x, tf_y: y, tf_is_training: True})  # train, set is_training=True
    if t % 10 == 0:
        # plotting
        plt.cla()
        o_loss_, d_loss_, o_out_, d_out_ = sess.run(
            [o_loss, d_loss, o_out, d_out], {tf_x: test_x, tf_y: test_y, tf_is_training: False} # test, set is_training=False
        )
# pytorch
    def add_layer(self, x, out_size, ac=None):
        x = tf.layers.dense(x, out_size, kernel_initializer=self.w_init, bias_initializer=B_INIT)
        self.pre_activation.append(x)
        # the momentum plays important rule. the default 0.99 is too high in this case!
        if self.is_bn: x = tf.layers.batch_normalization(x, momentum=0.4, training=tf_is_train)    # when have BN
        out = x if ac is None else ac(x)
        return out

当BN的training的参数为train时,只是表示BN的参数是可变化的,并不是代表BN会自己更新moving_mean 和moving_var,因为这个操作是前向更新的op,在做train之前必须确保moving_mean 和moving_var更新了,更新moving_mean 和moving_var的操作在tf.GraphKeys.UPDATE_OPS

 # !! IMPORTANT !! the moving_mean and moving_variance need to be updated,
        # pass the update_ops with control_dependencies to the train_op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.train = tf.train.AdamOptimizer(LR).minimize(self.loss)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 利用python检测文本相似性的三种方法

    利用python检测文本相似性的三种方法

    文本查重,也称为文本去重,是一项旨在识别文本文档之间的相似性或重复性的技术或任务,它的主要目标是确定一个文本文档是否包含与其他文档相似或重复的内容,本文给大家介绍了利用python检测文本相似性的原理和方法,需要的朋友可以参考下
    2023-11-11
  • python 自动监控最新邮件并读取的操作

    python 自动监控最新邮件并读取的操作

    这篇文章主要介绍了python 自动监控最新邮件并读取的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • Python常见库matplotlib学习笔记之画图文字的中文显示

    Python常见库matplotlib学习笔记之画图文字的中文显示

    在Python中使用matplotlib或者plotnine模块绘图时,常常出现图表中无法正常显示中文的问题,下面这篇文章主要给大家介绍了关于Python常见库matplotlib学习笔记之画图文字的中文显示的相关资料,需要的朋友可以参考下
    2023-05-05
  • django ajax json的实例代码

    django ajax json的实例代码

    今天就为大家分享一篇django ajax json的实例代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • pd.read_csv读取文件路径出现的问题解决

    pd.read_csv读取文件路径出现的问题解决

    本文主要介绍了pd.read_csv读取文件路径出现的问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-06-06
  • python数据结构的排序算法

    python数据结构的排序算法

    下面是是对python数据结构的排序算法的一些讲解及示意图,感兴趣的小伙伴一起来学习吧
    2021-08-08
  • Python数据结构之双向链表的定义与使用方法示例

    Python数据结构之双向链表的定义与使用方法示例

    这篇文章主要介绍了Python数据结构之双向链表的定义与使用方法,结合实例形式分析了Python双向链表的概念、原理、使用方法及相关注意事项,需要的朋友可以参考下
    2018-01-01
  • python绘制春节烟花的示例代码

    python绘制春节烟花的示例代码

    这篇文章主要介绍了使用python 实现的简单春节烟花效果的示例代码,请注意,运行本文的代码之前,请确保计算机上已经安装了Pygame库,需要的朋友可以参考下
    2024-02-02
  • python爬取天气数据的实例详解

    python爬取天气数据的实例详解

    在本篇文章里小编给大家整理的是一篇关于python爬取天气数据的实例详解内容,有兴趣的朋友们学习下。
    2020-11-11
  • Python字符串中如何去除数字之间的逗号

    Python字符串中如何去除数字之间的逗号

    这篇文章主要介绍了Python字符串中如何去除数字之间的逗号,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-05-05

最新评论