Pytorch 实现权重初始化

 更新时间:2019年12月31日 15:51:09   作者:idotc  
今天小编就为大家分享一篇Pytorch 实现权重初始化,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Tensorflow tf.nn.atrous_conv2d如何实现空洞卷积的

    Tensorflow tf.nn.atrous_conv2d如何实现空洞卷积的

    这篇文章主要介绍了Tensorflow tf.nn.atrous_conv2d如何实现空洞卷积的,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-04-04
  • 基于Python+OpenCV实现自动扫雷功能

    基于Python+OpenCV实现自动扫雷功能

    相信许多人很早就知道有扫雷这么一款经典的游(显卡测试)戏(软件),扫雷作为一款在Windows9x时代就已经诞生的经典游戏,从过去到现在依然都有着它独特的魅力,所以本文小编给大家介绍了如何使用Python+OpenCV实现自动扫雷效果,感兴趣的朋友可以参考下
    2023-12-12
  • pytorch获取模型某一层参数名及参数值方式

    pytorch获取模型某一层参数名及参数值方式

    今天小编就为大家分享一篇pytorch获取模型某一层参数名及参数值方式,具有很好的价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • Python中random函数的用法整理大全

    Python中random函数的用法整理大全

    random库是使用随机数的Python标准库,random库主要用于生成随机数,下面这篇文章主要给大家介绍了关于Python random函数用法的相关资料,文中通过图文以及实例代码介绍的非常详细,需要的朋友可以参考下
    2022-08-08
  • Python3.7 读取音频根据文件名生成脚本的代码

    Python3.7 读取音频根据文件名生成脚本的代码

    这篇文章主要介绍了Python3.7 读取音频根据文件名生成字幕脚本的方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-04-04
  • Python开发必备知识内存管理与垃圾回收

    Python开发必备知识内存管理与垃圾回收

    Python是一种高级编程语言,因其简洁而强大而备受欢迎,然而如其他编程语言一样,Python也面临着内存管理的挑战,在Python中,垃圾回收是一项关键任务,用于自动释放不再使用的内存,以避免内存泄漏,本文将介绍Python中的垃圾回收机制,以及如何通过优化代码来提高性能
    2023-11-11
  • Python使用pydub模块转换音频格式以及对音频进行剪辑

    Python使用pydub模块转换音频格式以及对音频进行剪辑

    这篇文章主要给大家介绍了关于Python使用pydub模块转换音频格式以及对音频进行剪辑的相关资料pydub是python的高级一个音频处理库,可以让你以一种不那么蠢的方法处理音频。需要的朋友可以参考下
    2021-06-06
  • Python实现五子棋联机对战小游戏

    Python实现五子棋联机对战小游戏

    本文主要介绍了通过Python实现简单的支持联机对战的游戏——支持局域网联机对战的五子棋小游戏。废话不多说,快来跟随小编一起学习吧
    2021-12-12
  • Python json读写方式和字典相互转化

    Python json读写方式和字典相互转化

    这篇文章主要介绍了Python json读写方式和字典相互转化,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-04-04
  • 单利模式及python实现方式详解

    单利模式及python实现方式详解

    单例模式(Singleton Pattern)是一种常用的软件设计模式,该模式的主要目的是确保 某一个类只有一个实例存在.这篇文章主要介绍了单利模式及python实现方式及Python单例模式的4种实现方法,需要的朋友可以参考下
    2018-03-03

最新评论