浅析pytorch中对nn.BatchNorm2d()函数的理解

 更新时间:2023年11月15日 16:48:30   作者:Code_LiShi  
Batch Normalization强行将数据拉回到均值为0,方差为1的正太分布上,一方面使得数据分布一致,另一方面避免梯度消失,这篇文章主要介绍了pytorch中对nn.BatchNorm2d()函数的理解,需要的朋友可以参考下

简介

机器学习中,进行模型训练之前,需对数据做归一化处理,使其分布一致。在深度神经网络训练过程中,通常一次训练是一个batch,而非全体数据。每个batch具有不同的分布产生了internal covarivate shift问题——在训练过程中,数据分布会发生变化,对下一层网络的学习带来困难。Batch Normalization强行将数据拉回到均值为0,方差为1的正太分布上,一方面使得数据分布一致,另一方面避免梯度消失。

计算

如图所示:

3. Pytorch的nn.BatchNorm2d()函数

其主要需要输入4个参数:
(1)num_features:输入数据的shape一般为[batch_size, channel, height, width], num_features为其中的channel;
(2)eps: 分母中添加的一个值,目的是为了计算的稳定性,默认:1e-5;
(3)momentum: 一个用于运行过程中均值和方差的一个估计参数,默认值为0.1.

(4)affine:当设为true时,给定可以学习的系数矩阵 γ \gamma γ和 β \beta β

4 代码示例

import torch
data = torch.ones(size=(2, 2, 3, 4))
data[0][0][0][0] = 25
print("data = ", data)
print("\n")
print("=========================使用封装的BatchNorm2d()计算================================")
BN = torch.nn.BatchNorm2d(num_features=2, eps=0, momentum=0)
BN_data = BN(data)
print("BN_data = ", BN_data)
print("\n")
print("=========================自行计算================================")
x = torch.cat((data[0][0], data[1][0]), dim=1)      # 1.将同一通道进行拼接(即把同一通道当作一个整体)
x_mean = torch.Tensor.mean(x)                       # 2.计算同一通道所有制的均值(即拼接后的均值)
x_var = torch.Tensor.var(x, False)                  # 3.计算同一通道所有制的方差(即拼接后的方差)
# 4.使用第一个数按照公式来求BatchNorm后的值
bn_first = ((data[0][0][0][0] - x_mean) / ( torch.pow(x_var, 0.5))) * BN.weight[0] + BN.bias[0]
print("bn_first = ", bn_first)

到此这篇关于pytorch中对nn.BatchNorm2d()函数的理解的文章就介绍到这了,更多相关pytorch nn.BatchNorm2d()函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 解决tensorflow1.x版本加载saver.restore目录报错的问题

    解决tensorflow1.x版本加载saver.restore目录报错的问题

    今天小编就为大家分享一篇解决tensorflow1.x版本加载saver.restore目录报错的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • Python中的装饰器用法详解

    Python中的装饰器用法详解

    这篇文章主要介绍了Python中的装饰器用法,以实例形式详细的分析了Python中的装饰器的使用技巧及相关注意事项,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-01-01
  • python中with用法讲解

    python中with用法讲解

    在本篇文章里小编给大家整理的是关于python中with用法讲解内容,有需要的朋友们可以参考下。
    2020-02-02
  • python仿抖音表白神器

    python仿抖音表白神器

    这篇文章主要教大家制作python抖音表白神器,仿制抖音表白小软件,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-04-04
  • 使用seaborn绘制强化学习中的图片问题

    使用seaborn绘制强化学习中的图片问题

    这篇文章主要介绍了使用seaborn绘制强化学习中的图片问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-01-01
  • Python爬取网页的所有内外链的代码

    Python爬取网页的所有内外链的代码

    这篇文章主要介绍了Python爬取网页的所有内外链,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-04-04
  • Django传递数据给前端的3种方式小结

    Django传递数据给前端的3种方式小结

    Django从后台往前台传递数据时有多种方法可以实现,下面这篇文章主要给大家介绍了关于Django传递数据给前端的3种方式,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2024-01-01
  • Python 异步之在 Asyncio中如何运行阻塞任务详解

    Python 异步之在 Asyncio中如何运行阻塞任务详解

    这篇文章主要为大家介绍了Python 异步之在 Asyncio 中运行阻塞任务示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-03-03
  • python计算时间差的方法

    python计算时间差的方法

    这篇文章主要介绍了python计算时间差的方法,实例分析了Python时间操作的相关模块与技巧,需要的朋友可以参考下
    2015-05-05
  • python爬取抖音视频的实例分析

    python爬取抖音视频的实例分析

    在本篇内容里小编给大家整理一篇关于python爬取抖音视频的实例分析的相关内容,有兴趣的朋友可以测试下实例内容。
    2021-01-01

最新评论