pytorch使用resnet快速加载官方提供的预训练模型

 更新时间:2023年09月09日 09:34:03   作者:Tchunren  
这篇文章主要介绍了pytorch使用resnet快速加载官方提供的预训练模型方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

使用resnet快速加载官方提供的预训练模型

在做神经网络的搭建过程,经常使用pytorch中的resnet作为backbone,特别是resnet50,

比如下面的这个网络设定:

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision import models
class base_resnet(nn.Module):
    def __init__(self):
        super(base_resnet, self).__init__()
        self.model = models.resnet50(pretrained=True)
        #self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
        self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        # x = x.view(x.size(0), x.size(1))
        return x

该网络相当于继承了resnet50的所有参数结构,只不过是在forward中,改变了数据的传输过程,没有经过最后的特征展开以及线性分类。

在下面的这行代码中,是相当于调用了pytoch中定义的resnet50网络,并且会自动下载并且加载训练好的网络参数,如果调为 pretrained=False,则不会加载训练好的参数,而是随机进行参数的赋值。

但是我在服务器上跑这一类代码的时候发现,每当我重新跑一次程序,如果设置为True都会重新下载resnet50训练好的参数,但是由于有时候网络特别不好,导致我下载个基础的resnet50就要耗费我好长时间,那么我就想能不能将这个resnet50的参数提前下载好,使用的时候直接加载呢。

当然是能了。

self.model = models.resnet50(pretrained=True)

我们可以根据我们使用的结构,到对应的地址下载对应的模型到本地,常用的resnet的地址如下:

 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',

将其下载下来,然后将模型放入到和net.py同目录的model文件夹下面,然后使用下面的代码就可以避免每次都重新下载模型的问题了。

self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))

pytorch代码规范之加载预训练模型

加载预训练模型,并去除需要再次训练的层

model=resnet()#自己构建的模型,以resnet为例, 需要重新训练的层的名字要和之前的不同。
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
if k!='xxx.weight' and k!='xxx.bias' :
v.requires_grad=False#固定参数

训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

检查是否固定

for k,v in model.named_parameters():
if k!='xxx.weight' and k!='xxx.bias' :
print(v.requires_grad)#理想状态下,所有值都是False

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Pycharm激活方法及详细教程(详细且实用)

    Pycharm激活方法及详细教程(详细且实用)

    这篇文章主要介绍了Pycharm激活方法及详细教程,本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧
    2020-05-05
  • python通过ElementTree操作XML

    python通过ElementTree操作XML

    这篇文章介绍了python通过ElementTree操作XML的方法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-07-07
  • Python下opencv使用hough变换检测直线与圆

    Python下opencv使用hough变换检测直线与圆

    在数字图像中,往往存在着一些特殊形状的几何图形,像检测马路边一条直线,检测人眼的圆形等等,有时我们需要把这些特定图形检测出来,本文就详细的介绍了一下方法
    2021-06-06
  • E: 无法定位软件包 python3-pip问题及解决

    E: 无法定位软件包 python3-pip问题及解决

    这篇文章主要介绍了E: 无法定位软件包 python3-pip问题及解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-02-02
  • Python实现二叉堆

    Python实现二叉堆

    二叉堆是一种特殊的堆,二叉堆是完全二元树(二叉树)或者是近似完全二元树(二叉树)。二叉堆有两种:最大堆和最小堆。最大堆:父结点的键值总是大于或等于任何一个子节点的键值;最小堆:父结点的键值总是小于或等于任何一个子节点的键值。
    2016-02-02
  • python教程对函数中的参数进行排序

    python教程对函数中的参数进行排序

    这篇文章主要介绍了python教程对函数中的参数进行排序的方法讲解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2021-09-09
  • Datawhale练习之二手车价格预测

    Datawhale练习之二手车价格预测

    此篇文章是关于Datawhale练习,代码完整,但由于该数据集中数据特征较少(39维),以下可作为少量特征情况下的分析。当特征数目过大(成千上万)时,需要继续学习。需要的朋友可以参考下
    2021-04-04
  • 详解如何用TensorFlow训练和识别/分类自定义图片

    详解如何用TensorFlow训练和识别/分类自定义图片

    这篇文章主要介绍了详解如何用TensorFlow训练和识别/分类自定义图片,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-08-08
  • python安装numpy&安装matplotlib& scipy的教程

    python安装numpy&安装matplotlib& scipy的教程

    下面小编就为大家带来一篇python安装numpy&安装matplotlib& scipy的教程。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2017-11-11
  • 一篇文章带你深入学习Python函数

    一篇文章带你深入学习Python函数

    这篇文章主要带大家深入学习Python函数,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2022-01-01

最新评论