pytorch加载自定义网络权重的实现

 更新时间:2020年01月07日 14:25:31   作者:wuming无名  
今天小编就为大家分享一篇pytorch加载自定义网络权重的实现,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在将自定义的网络权重加载到网络中时,报错:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

我们一步一步分析。

模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')

(1)查看获取模型权重的源码:

pytorch源码:net.state_dict()

def state_dict(self, destination=None, prefix='', keep_vars=False):
  r"""Returns a dictionary containing a whole state of the module.

  Both parameters and persistent buffers (e.g. running averages) are
  included. Keys are corresponding parameter and buffer names.

  Returns:
    dict:
      a dictionary containing a whole state of the module

  Example::

    >>> module.state_dict().keys()
    ['bias', 'weight']

  """

将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!

(2)查看保存模型权重的源码:

pytorch源码:torch.save()

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
  """Saves an object to a disk file.

  See also: :ref:`recommend-saving-models`

  Args:
    obj: saved object
    f: a file-like object (has to implement write and flush) or a string
      containing a file name
    pickle_module: module used for pickling metadata and objects
    pickle_protocol: can be specified to override the default protocol

  .. warning::
    If you are using Python 2, torch.save does NOT support StringIO.StringIO
    as a valid file-like object. This is because the write method should return
    the number of bytes written; StringIO.write() does not do this.

    Please use something like io.BytesIO instead.

函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。

解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()

#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))

解决方法很简单,主要记录解决思路。

以上这篇pytorch加载自定义网络权重的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python3批量创建Crowd用户并分配组

    Python3批量创建Crowd用户并分配组

    这篇文章主要介绍了Python3批量创建Crowd用户并分配组,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • Python浅析多态与鸭子类型使用实例

    Python浅析多态与鸭子类型使用实例

    python是一门解释语言,但是同java等静态语言一样,是可以通过继承的方式实现多态。而且python还有一个自己的特殊实现多态的方法,就是通过鸭子类型,来实现多态
    2022-10-10
  • Python模块Typing.overload的使用场景分析

    Python模块Typing.overload的使用场景分析

    在 Python 中,typing.overload 是一个用于定义函数重载的装饰器,函数重载是指在一个类中可以定义多个相同名字但参数不同的函数,使得在调用函数时可以根据参数的不同选择不同的函数执行,这篇文章主要介绍了Python模块Typing.overload的使用,需要的朋友可以参考下
    2024-02-02
  • python机器学习案例教程——K最近邻算法的实现

    python机器学习案例教程——K最近邻算法的实现

    本篇文章主要介绍了python机器学习案例教程——K最近邻算法的实现,详细的介绍了K最近邻算法的概念和示例,具有一定的参考价值,有兴趣的可以了解一下
    2017-12-12
  • Python拼接微信好友头像大图的实现方法

    Python拼接微信好友头像大图的实现方法

    这篇文章主要介绍了Python拼接微信好友头像大图的实现方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-08-08
  • python-opencv-cv2.threshold()二值化函数的使用

    python-opencv-cv2.threshold()二值化函数的使用

    这篇文章主要介绍了python-opencv-cv2.threshold()二值化函数的使用,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-11-11
  • python numpy实现文件存取的示例代码

    python numpy实现文件存取的示例代码

    这篇文章主要介绍了python numpy实现文件存取的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-05-05
  • Python实现多线程并发请求测试的脚本

    Python实现多线程并发请求测试的脚本

    这篇文章主要为大家分享了一个Python实现多线程并发请求测试的脚本,文中的示例代码简洁易懂,具有一定的借鉴价值,需要的小伙伴可以了解一下
    2023-06-06
  • Python实现实时监测可视化数据大屏

    Python实现实时监测可视化数据大屏

    实时监测的可视化数据大屏是一种非常有用的工具,可以帮助我们实时了解数据的变化和趋势,下面我们将介绍如何使用Python代码实现实时监测的可视化数据大屏,需要的可以参考一下
    2023-06-06
  • python 提取html文本的方法

    python 提取html文本的方法

    在解决自然语言处理问题时,有时你需要获得大量的文本集。互联网是文本的最大来源,但是从任意HTML页面提取文本是一项艰巨而痛苦的任务。本文将讲述python高效提取html文本的方法
    2021-05-05

最新评论