Pytorch中实现只导入部分模型参数的方式

 更新时间:2020年01月02日 16:57:09   作者:咆哮的阿杰  
今天小编就为大家分享一篇Pytorch中实现只导入部分模型参数的方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Pandas中Series的创建及数据类型转换

    Pandas中Series的创建及数据类型转换

    这篇文章主要介绍了Pandas中Series的创建及数据类型转换,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-08-08
  • python爬虫scrapy基于CrawlSpider类的全站数据爬取示例解析

    python爬虫scrapy基于CrawlSpider类的全站数据爬取示例解析

    这篇文章主要介绍了python爬虫scrapy基于CrawlSpider类的全站数据爬取示例解析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-02-02
  • Python中闭包和自由变量的使用与注意事项

    Python中闭包和自由变量的使用与注意事项

    这篇文章主要给大家介绍了关于Python中闭包和自由变量的相关资料,需要的朋友可以参考下
    2022-03-03
  • Python利用Turtle绘制虎年图像

    Python利用Turtle绘制虎年图像

    2022年是农历壬寅虎年,在自然界中,虎有“百兽之王”之称。本文也将利用Python中的Turtle绘制一个卡通的虎年图像,感兴趣的可以学习一下
    2022-01-01
  • python 教程之blinker 信号库

    python 教程之blinker 信号库

    这篇文章主要介绍了python 教程之blinker 信号库,文章基于python的相关资料展开详细的内容说明。具有一定的参考价价值,需要的小伙伴可以参考一下
    2022-05-05
  • VSCode搭建Django开发环境的图文步骤

    VSCode搭建Django开发环境的图文步骤

    本篇介绍在vscode环境下搭建Django开发环境的详细步骤,包括Python、Django、VSCode等,以及它们的安装和配置方法,具有一定的参考价值,感兴趣的可以了解一下
    2023-09-09
  • openCV入门学习基础教程第一篇

    openCV入门学习基础教程第一篇

    OpenCV是计算机视觉领域一个图像和视频处理库,用于各种图像和视频分析,如面部识别和检测,车牌阅读,照片编辑,高级机器人视觉,光学字符识别等等,下面这篇文章主要给大家介绍了关于openCV入门学习基础教程第一篇的相关资料,需要的朋友可以参考下
    2022-11-11
  • python多个模块py文件的数据共享实例

    python多个模块py文件的数据共享实例

    今天小编就为大家分享一篇python多个模块py文件的数据共享实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • Pycharm Plugins加载失败问题解决方案

    Pycharm Plugins加载失败问题解决方案

    这篇文章主要介绍了Pycharm Plugins加载失败问题解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11
  • 解读keras中的正则化(regularization)问题

    解读keras中的正则化(regularization)问题

    这篇文章主要介绍了解读keras中的正则化(regularization)问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-12-12

最新评论