pytorch模型保存方式
pytorch模型保存
保存模型主要分为两类:
- 保存整个模型
- 只保存模型参数
1.保存加载整个模型(不推荐)
保存整个网络模型,网络结构+权重参数
torch.save(model,'net.pth')
加载整个网络模型(可能比较耗时)
model=torch.load('net.pth')2.只保存加载模型参数(推荐)
保存模型的权重参数(速度快,占内存少)
torch.save(model.state_dict(),'net_params.pth')
load 模型参数
因为我们只保存了 模型的参数,所以需要先定义一个网络对象,然后再加载模型参数。
model=myNet()
#将模型参数加载到新模型中,torch.load返回的是一个OrderedDict,说明.state_dict()只是把所有模型的参数都已OrderedDict的形式存下来。
state_dict=torch.load('net_params.pth')
model.load_state_dict(state_dict)Note:保存模型进行推理测试时,只需保存训练好的模型的权重参数,即推荐第二种方法。
load_state_dict的参数strict=False new_model.load_state_dict(state_dict,strict=False)
如果哪一天我们需要重新写这个网络的,比如使用new_model,如果直接load会出现unexpected key.
但是加上strict=False可以很容易地加载预训练的参数(注意检查key是否匹配),直接忽略不匹配的key,对于匹配的key则进行正常的赋值。
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
使用Python脚本zabbix自定义key监控oracle连接状态
这篇文章主要介绍了使用Python脚本zabbix自定义key监控oracle连接状态,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下2019-08-08
Python连接Postgres/Mysql/Mongo数据库基本操作大全
在后端应用开发中,经常会用到Postgres/Mysql/Mongo这三种数据库的基本操作,今天小编就给大家详细介绍Python连接Postgres/Mysql/Mongo数据库基本操作,感兴趣的朋友一起看看吧2021-06-06
详解使用python3.7配置开发钉钉群自定义机器人(2020年新版攻略)
这篇文章主要介绍了详解使用python3.7配置开发钉钉群自定义机器人(2020年新版攻略),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧2020-04-04
Python使用concurrent.futures模块实现多进程多线程编程
Python的concurrent.futures模块可以很方便的实现多进程、多线程运行,减少了多进程带来的的同步和共享数据问题,下面就跟随小编一起了解一下concurrent.futures模块的具体使用吧2023-12-12


最新评论