pytorch加载预训练模型与自己模型不匹配的解决方案

 更新时间:2021年05月13日 16:33:18   作者:找不到服务器1703  
这篇文章主要介绍了pytorch加载预训练模型与自己模型不匹配的解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • pygame游戏之旅 添加游戏界面按键图形

    pygame游戏之旅 添加游戏界面按键图形

    这篇文章主要为大家详细介绍了pygame游戏之旅的第10篇,教大家如何添加游戏界面按键图形,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-11-11
  • python自动发送QQ邮箱的完整步骤

    python自动发送QQ邮箱的完整步骤

    最近在自己学习Python爬虫,学到了用Python发送邮件,觉得这个可能以后比较实用,所以下面这篇文章主要给大家介绍了关于python自动发送QQ邮箱的相关资料,需要的朋友可以参考下
    2021-11-11
  • python爬虫爬取网页表格数据

    python爬虫爬取网页表格数据

    这篇文章主要为大家详细介绍了python爬虫爬取网页表格数据,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03
  • 实例讲解python中的协程

    实例讲解python中的协程

    在本篇文章里我们给大家通过实例讲述一下关于python中的协程相关知识点内容,需要的朋友们可以参考下。
    2018-10-10
  • Python轻量级ORM框架Peewee访问sqlite数据库的方法详解

    Python轻量级ORM框架Peewee访问sqlite数据库的方法详解

    这篇文章主要介绍了Python轻量级ORM框架Peewee访问sqlite数据库的方法,结合实例形式较为详细的分析了ORM框架的概念、功能及peewee的安装、使用及操作sqlite数据库的方法,需要的朋友可以参考下
    2017-07-07
  • Python 文件操作的详解及实例

    Python 文件操作的详解及实例

    这篇文章主要介绍了Python 文件操作的详解及实例的相关资料,希望通过本文大家能够理解掌握Python 文件操作的知识,需要的朋友可以参考下
    2017-09-09
  • Python实现登录人人网并抓取新鲜事的方法

    Python实现登录人人网并抓取新鲜事的方法

    这篇文章主要介绍了Python实现登录人人网并抓取新鲜事的方法,可实现Python模拟登陆并抓取新鲜事的功能,需要的朋友可以参考下
    2015-05-05
  • python批量解压zip文件的方法

    python批量解压zip文件的方法

    这篇文章主要介绍了python批量解压zip文件的方法,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-08-08
  • Python+selenium实现截图图片并保存截取的图片

    Python+selenium实现截图图片并保存截取的图片

    这篇文章介绍如何利用Selenium的方法进行截图并保存截取的图片,需要的朋友参考下本文
    2018-01-01
  • Flask框架学习笔记之模板操作实例详解

    Flask框架学习笔记之模板操作实例详解

    这篇文章主要介绍了Flask框架学习笔记之模板操作,结合实例形式详细分析了flask框架模板引擎Jinja2的模板调用、模板继承相关原理与操作技巧,需要的朋友可以参考下
    2019-08-08

最新评论