pytorch使用resnet快速加载官方提供的预训练模型
使用resnet快速加载官方提供的预训练模型
在做神经网络的搭建过程,经常使用pytorch中的resnet作为backbone,特别是resnet50,
比如下面的这个网络设定:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision import models
class base_resnet(nn.Module):
def __init__(self):
super(base_resnet, self).__init__()
self.model = models.resnet50(pretrained=True)
#self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
x = self.model.layer2(x)
x = self.model.layer3(x)
x = self.model.layer4(x)
x = self.model.avgpool(x)
# x = x.view(x.size(0), x.size(1))
return x该网络相当于继承了resnet50的所有参数结构,只不过是在forward中,改变了数据的传输过程,没有经过最后的特征展开以及线性分类。
在下面的这行代码中,是相当于调用了pytoch中定义的resnet50网络,并且会自动下载并且加载训练好的网络参数,如果调为 pretrained=False,则不会加载训练好的参数,而是随机进行参数的赋值。
但是我在服务器上跑这一类代码的时候发现,每当我重新跑一次程序,如果设置为True都会重新下载resnet50训练好的参数,但是由于有时候网络特别不好,导致我下载个基础的resnet50就要耗费我好长时间,那么我就想能不能将这个resnet50的参数提前下载好,使用的时候直接加载呢。
当然是能了。
self.model = models.resnet50(pretrained=True)
我们可以根据我们使用的结构,到对应的地址下载对应的模型到本地,常用的resnet的地址如下:
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
将其下载下来,然后将模型放入到和net.py同目录的model文件夹下面,然后使用下面的代码就可以避免每次都重新下载模型的问题了。
self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))pytorch代码规范之加载预训练模型
加载预训练模型,并去除需要再次训练的层
model=resnet()#自己构建的模型,以resnet为例, 需要重新训练的层的名字要和之前的不同。
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)固定部分参数
#k是可训练参数的名字,v是包含可训练参数的一个实体 #可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中: for k,v in model.named_parameters(): if k!='xxx.weight' and k!='xxx.bias' : v.requires_grad=False#固定参数
训练部分参数
#将要训练的参数放入优化器 optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)
检查是否固定
for k,v in model.named_parameters(): if k!='xxx.weight' and k!='xxx.bias' : print(v.requires_grad)#理想状态下,所有值都是False
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
使用pymysql查询数据库,把结果保存为列表并获取指定元素下标实例
这篇文章主要介绍了使用pymysql查询数据库,把结果保存为列表并获取指定元素下标实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2020-05-05
一文让你彻底搞懂Python中__str__和__repr__
这篇文章主要介绍了Python中的__str__和__repr__的异同,__str__和__repr__是基本的内置方法,文中有详细的代码示例,感兴趣的同学可以参考阅读下2023-05-05
Python处理PDF文档的两大功能库(PyPDF2/pdfplumber)的使用指南
PyPDF2擅长PDF文档操作,而pdfplumber专注于高质量文本和表格数据提取,本文将对比Python中两个PDF处理库PyPDF2和pdfplumber的主要功能与适用场景,希望对大家有所帮助2026-03-03


最新评论