Pytorch如何加载部分权重

 更新时间:2023年09月15日 10:14:16   作者:Mr_寒路  
这篇文章主要介绍了Pytorch如何加载部分权重问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

1.修改网络层输出

比如在人脸检测项目中,已经训练好人脸框的回归,但是此时需要再加入人脸关键点。

为了节约大量时间,我们可以加载部分权重。加载的网络权重

if os.path.exists(self.load_params):
	pretext_model = torch.load(self.load_params)

打印出来,会看到网络权重存储在一个字典中,需要修改哪一层,用字典的键索引值进行修改。

比如原本输出层为4,我将网络输出层修改为14,又由于输出的都是坐标值,属于同一分布,所以我将原参4复制扩充为了14,效果非常好。

w = pretext_model["fc2.weight"]
b = pretext_model["fc2.bias"]
pretext_model["fc2.weight"] = torch.cat((w,w,w,w[:2]),dim=0)
pretext_model["fc2.bias"] = torch.cat((b,b,b,b[:2]),dim=0)

最后加载修改后的参数

self.net.load_state_dict(pretext_model)

2.删除或增加了网络层

查看模型的参数,也是存放在一个字典中

if os.path.exists(self.load_params):
	pretext_model = torch.load(self.load_params) #加载的参数
	model_dict = net.state_dict()  #模型参数
	print(model_dict)
	print(pretext_model)
#如果模型有k层,就加载
state_dict = {k: v for k, v in pretext_model.items() if k in model_dict.keys()}
model_dict.update(state_dict)
net.load_state_dict(model_dict)

3.迁移学习

有时我们也会用别人的模型,加载与训练参数,但是需要对输出层做一些修改,一般有两种方法,直接修改输出层个数或增加网络层

修改输出层个数

net = models.vgg19(pretrained=True) #下载与训练参数
print(net)  #查看网络结构
net.classifier[6] = torch.nn.Linear(4096,10) #将输出层修改为10分类

增加输出网络层

num_fc_ftr = net.classifier[6]
net.fc = nn.Linear(num_fc_ftr, 128)
net.out = nn.Linear(128, 10)

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 详解Selenium如何实现获取cookies并保存

    详解Selenium如何实现获取cookies并保存

    这篇文章主要为大家详细介绍了Selenium如何实现获取cookies保存起来用于下次访问,文中的示例代码讲解详细,感兴趣的小伙伴可以了解一下
    2023-05-05
  • python随机打印成绩排名表

    python随机打印成绩排名表

    这篇文章主要为大家详细介绍了python随机打印成绩排名表,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-06-06
  • Opencv图像添加椒盐噪声、高斯滤波去除噪声原理以及手写Python代码实现方法

    Opencv图像添加椒盐噪声、高斯滤波去除噪声原理以及手写Python代码实现方法

    椒盐噪声的特征非常明显,为图像上有黑色和白色的点,下面这篇文章主要给大家介绍了关于Opencv图像添加椒盐噪声、高斯滤波去除噪声原理以及手写Python代码实现的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-09-09
  • 基于Python编写PDF转EPUB以及MOBI工具

    基于Python编写PDF转EPUB以及MOBI工具

    当我们需要在电子阅读器上阅读这些文档时,转换为EPUB或MOBI格式会提供更好的阅读体验,所以本文将使用Python编写一个PDF转EPUB以及MOBI工具,需要的可以参考下
    2025-03-03
  • Python docx库用法示例分析

    Python docx库用法示例分析

    这篇文章主要介绍了Python docx库用法,结合实例形式分析了docx库相关的docx文件读取、文本添加、格式操作,需要的朋友可以参考下
    2019-02-02
  • django 数据库 get_or_create函数返回值是tuple的问题

    django 数据库 get_or_create函数返回值是tuple的问题

    这篇文章主要介绍了django 数据库 get_or_create函数返回值是tuple的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • python实现抽奖小程序

    python实现抽奖小程序

    这篇文章主要为大家详细介绍了python实现抽奖小程序,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-05-05
  • Python报错error: subprocess-exited-with-error解决办法

    Python报错error: subprocess-exited-with-error解决办法

    在Python开发中,遇到subprocess-exited-with-error通常是由依赖缺失、权限问题、环境配置错误或兼容性问题导致,修复方法包括安装依赖、使用虚拟环境、提升权限、检查路径和命令,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2024-10-10
  • Python3爬虫学习之爬虫利器Beautiful Soup用法分析

    Python3爬虫学习之爬虫利器Beautiful Soup用法分析

    这篇文章主要介绍了Python3爬虫学习之爬虫利器Beautiful Soup用法,结合实例形式分析了Beautiful Soup的功能、使用方法及相关操作注意事项,需要的朋友可以参考下
    2018-12-12
  • Python使用cx_Oracle调用Oracle存储过程的方法示例

    Python使用cx_Oracle调用Oracle存储过程的方法示例

    这篇文章主要介绍了Python使用cx_Oracle调用Oracle存储过程的方法,结合具体实例分析了Python中通过cx_Oracle调用PL/SQL的具体步骤与相关操作技巧,需要的朋友可以参考下
    2017-10-10

最新评论