解决pytorch 的state_dict()拷贝问题

 更新时间:2021年03月03日 11:24:59   作者:Luke_Ye  
这篇文章主要介绍了解决pytorch 的state_dict()拷贝问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

相关文章

  • Python-Selenium自动化爬虫

    Python-Selenium自动化爬虫

    本文介绍Python-Selenium自动化爬虫,Selenium是一个Web的自动化测试工具,最初是为网站自动化测试而开发的,Selenium 可以直接运行在浏览器上,它支持所有主流的浏览器,可以接收指令,让浏览器自动加载页面,获取需要的数据,甚至页面截屏,xiamian neir 需要的朋友可以参考下
    2022-01-01
  • Python ini配置文件示例详解

    Python ini配置文件示例详解

    这篇文章主要给大家介绍了关于Python ini配置文件的相关资料,文中通过实例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2022-03-03
  • 深入浅析ImageMagick命令执行漏洞

    深入浅析ImageMagick命令执行漏洞

    ImageMagick是一个功能强大的开源图形处理软件,可以用来读、写和处理超过90种的图片文件,包括流行的JPEG、GIF、 PNG、PDF以及PhotoCD等格式。接下来通过本文给大家浅析ImageMagick命令执行漏洞的知识,一起看看吧
    2016-10-10
  • Python检测PE所启用保护方式详解

    Python检测PE所启用保护方式详解

    Python通过pywin32模块调用WindowsAPI接口,可以实现对特定进程加载模块的枚举输出并检测该PE程序模块所启用的保护方式,感兴趣的可以了解一下
    2022-10-10
  • Python连接字符串过程详解

    Python连接字符串过程详解

    这篇文章主要介绍了python连接字符串过程详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-01-01
  • 如何将pytorch模型部署到安卓上的方法示例

    如何将pytorch模型部署到安卓上的方法示例

    这篇文章演示如何将训练好的pytorch模型部署到安卓设备上,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-02-02
  • Python Django路径配置实现过程解析

    Python Django路径配置实现过程解析

    这篇文章主要介绍了Python Django路径配置实现过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11
  • Python中表格插件Tabulate的用法小结

    Python中表格插件Tabulate的用法小结

    这篇文章主要介绍了Python中表格插件Tabulate的用法,Tabulate插件是一个功能强大、简单易用的数据可视化工具,它能够满足我们在Python中进行表格数据展示的各种需求,通过使用Tabulate插件,我们能够轻松地生成美观且易读的表格,需要的朋友可以参考下
    2023-11-11
  • python装饰器代替set get方法实例

    python装饰器代替set get方法实例

    今天小编就为大家分享一篇python装饰器代替set get方法实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • 轻松掌握python设计模式之策略模式

    轻松掌握python设计模式之策略模式

    这篇文章主要帮助大家轻松掌握python设计模式之策略模式,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2016-11-11

最新评论