PyTorch深度学习模型的保存和加载流程详解
更新时间:2021年10月21日 09:32:00 作者:软耳朵DONG
PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch,这篇文章主要介绍了PyTorch模型的保存和加载流程
一、模型参数的保存和加载
-
torch.save(module.state_dict(), path)
:使用module.state_dict()
函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path
所指定的文件存放路径(常用文件格式为.pt
、.pth
或.pkl
)。 torch.nn.Module.load_state_dict(state_dict)
:从state_dict
中加载参数和缓冲区到Module
及其子类中 。torch.nn.Module.state_dict()
函数返回python
中的一个OrderedDict
类型字典对象,该对象将每一层与它的对应参数和缓冲区建立映射关系,字典的键值是参数或缓冲区的名称。只有那些参数可以训练的层才会被保存到OrderedDict
中,例如:卷积层、线性层等。Python
中的字典类以“键:值
”方式存取数据,OrderedDict
是它的一个子类,实现了对字典对象中元素的排序(OrderedDict
根据放入元素的先后顺序进行排序)。由于进行了排序,所以顺序不同的两个OrderedDict
字典对象会被当做是两个不同的对象。- 示例:
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x # 初始化网络 net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 获取state_dict state_dict = net.state_dict() # 字典的遍历默认是遍历key,所以param_tensor实际上是键值 for param_tensor in state_dict: print(param_tensor,':\n',state_dict[param_tensor]) # 保存模型参数 torch.save(state_dict,"net_params.pth") # 通过加载state_dict获取模型参数 net.load_state_dict(state_dict)
输出:
二、完整模型的保存和加载
-
torch.save(module, path)
:将训练完的整个网络模型module
保存到path
所指定的文件存放路径(常用文件格式为.pt
或.pth
)。 torch.load(path)
:加载保存到path
中的整个神经网络模型。- 示例:
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x # 初始化网络 net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 保存整个网络 torch.save(net,"net.pth") # 加载网络 net = torch.load("net.pth")
到此这篇关于PyTorch深度学习模型的保存和加载流程详解的文章就介绍到这了,更多相关PyTorch 模型的保存 内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
Python 格式化输出_String Formatting_控制小数点位数的实例详解
在本篇文章里小编给大家整理了关于Python 格式化输出_String Formatting_控制小数点位数的实例内容,需要的朋友们参考下。2020-02-02VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的方法详解
这篇文章主要介绍了VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的方法,较为详细的分析了VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的具体步骤、相关命令与操作注意事项,需要的朋友可以参考下2019-07-07国产化设备鲲鹏CentOS7上源码安装Python3.7的过程详解
这篇文章主要介绍了国产化设备鲲鹏CentOS7上源码安装Python3.7,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下2022-05-05
最新评论