PyTorch加载模型model.load_state_dict()问题及解决

 更新时间:2023年02月03日 14:16:42   作者:是否龙磊磊真的一无所有  
这篇文章主要介绍了PyTorch加载模型model.load_state_dict()问题及解决,具有很好的参考价值,希望对大家有所帮助。

PyTorch加载模型model.load_state_dict()问题

希望将训练好的模型加载到新的网络上。

如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题。

Unexpected key(s) in state_dict: "module.features. ...".,Expected ".features....". 直接原因是key值名字不对应。

表明了加载过程中,期望获得的key值为feature...,而不是module.features....。

这是由模型保存过程中导致的,模型应该是在DataParallel模式下面,也就是采用了多GPU训练模型,然后直接保存的。

You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without . You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back.

解决上面的问题有三个办法: 

1. 对load的模型创建新的字典

去掉不需要的key值"module".

# original saved file with DataParallel
state_dict = torch.load('checkpoint.pt')  # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`,表面从第7个key值字符取到最后一个字符,正好去掉了module.
    new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。 
# load params
model.load_state_dict(new_state_dict) # 从新加载这个模型。

2. 直接用空白''代替'module.'

model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()})
 
# 相当于用''代替'module.'。
#直接使得需要的键名等于期望的键名。

3. 最简单的方法

加载模型之后,接着将模型DataParallel,此时就可以load_state_dict。

如果有多个GPU,将模型并行化,用DataParallel来操作。

这个过程会将key值加一个"module. ***"。

model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():
    print(k) #只打印key值,不打印具体参数。

4. 总结

从出错显示的问题就可以看出,key值不匹配,因此可以选择多种方法,将模型参数加载进去。

这个方法通常会在load_state_dict过程中遇到。将训练好的一个网络参数,移植到另外一个网络上面,继续训练。

或者将训练好的网络checkpoint加载进模型,再次进行训练。可以打印出model state_dict来看出两者的差别。

model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():
    print(k) #只打印key值,不打印具体参数。

features.0.0.weight   
features.0.1.weight
features.1.conv.3.weight
features.1.conv.4.num_batches_tracked

model = VGGNet()
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
# Load weights to resume from checkpoint。
# print('**************************************')
# 这个方法能够直接打印出你保存的checkpoint的键和值。
for k,v in checkpoint.items():
    print(k) 
print("*****************************************")
 

输出结果为:

module.features.0.0.weight",

"module.features.0.1.weight",

"module.features.0.1.bias

可以看出不匹配,模型的参数中,key值不同,多了module。

PS: 追加

在移植参数的过程中,对于出现 .total_ops和.total_params结尾的参数,可参考以下代码:

from collections import OrderedDict
checkpoint = torch.load(
    pretrained_model_file_path,
    map_location=(None if use_cuda and not remap_to_cpu else "cpu"))
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    if not k.endswith('total_ops') and not k.endswith('total_params'):
        name = k[7:]
        new_state_dict[name] = v

最后

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • tornado捕获和处理404错误的方法

    tornado捕获和处理404错误的方法

    这篇文章主要介绍了tornado捕获和处理404错误的方法,方法很简单,只要覆写write_error方法就可以,看下面的代码就明白了
    2014-02-02
  • Pandas读取csv时如何设置列名

    Pandas读取csv时如何设置列名

    这篇文章主要介绍了Pandas读取csv时如何设置列名,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-06-06
  • python 简单的股票基金爬虫

    python 简单的股票基金爬虫

    最近基金非常火爆,很多原本不投资、不理财人,也开始讨论、参与买基金了。根据投资对象的不同,基金分为股票型基金、债券基金、混合型基金、货币基金。所以今天我们就来看看,这些基金公司都喜欢买那些公司的股票。
    2021-06-06
  • Python爬虫破解登陆哔哩哔哩的方法

    Python爬虫破解登陆哔哩哔哩的方法

    这篇文章主要介绍了Python爬虫破解登陆哔哩哔哩的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11
  • Python深入学习之上下文管理器

    Python深入学习之上下文管理器

    这篇文章主要介绍了Python深入学习之上下文管理器,上下文管理器是在Python2.5加入的功能,它能够让你的代码可读性更强并且错误更少,和C#中的using语句类似,需要的朋友可以参考下
    2014-08-08
  • Pycharm导入Python包,模块的图文教程

    Pycharm导入Python包,模块的图文教程

    今天小编就为大家分享一篇Pycharm导入Python包,模块的图文教程,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • 为什么是 Python -m

    为什么是 Python -m

    这篇文章给大家介绍了Python -m的含义及python -m 和 python 的区别解析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧
    2020-06-06
  • pytorch 计算ConvTranspose1d输出特征大小方式

    pytorch 计算ConvTranspose1d输出特征大小方式

    这篇文章主要介绍了pytorch 计算ConvTranspose1d输出特征大小方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • 说说如何遍历Python列表的方法示例

    说说如何遍历Python列表的方法示例

    这篇文章主要介绍了如何遍历Python列表的方法示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-02-02
  • Python基于当前时间批量创建文件

    Python基于当前时间批量创建文件

    这篇文章主要介绍了Python基于当前时间批量创建文件,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05

最新评论