Python中torch.load()加载模型以及其map_location参数详解

 更新时间:2022年09月23日 09:55:41   作者:eecspan  
torch.load()作用用来加载torch.save()保存的模型文件,下面这篇文章主要给大家介绍了关于Python中torch.load()加载模型以及其map_location参数的相关资料,需要的朋友可以参考下

参考

TORCH.LOAD

torch.load()

函数格式为:torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args),一般我们使用的时候,基本只使用前两个参数。

模型的保存

模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过model.load_state_dict(dict)将模型的参数更新。

另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。

具体可参考:PyTorch模型的保存与加载

模型加载中的map_location参数

具体来说,map_location参数是用于重定向,比如此前模型的参数是在cpu中的,我们希望将其加载到cuda:0中。或者我们有多张卡,那么我们就可以将卡1中训练好的模型加载到卡2中,这在数据并行的分布式深度学习中可能会用到。

首先定义一个AlexNet,并使用cuda:0将其训练了一个猫狗分类,之后把模型存储起来。

map_location=None

我们先把state_dict加载进来。

model_path = "./cuda_model.pth"
model = torch.load(model_path)
print(next(model.parameters()).device)

结果为:

cuda:0

因为保存的时候就是模型就是cuda:0的,所以加载进来也是。

map_location=torch.device()

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location=torch.device('cpu'))
print(next(model.parameters()).device)

结果为:

cpu

模型从cuda:0变成了cpu

map_location={xx:xx}

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:0':'cuda:1'})
print(next(model.parameters()).device)

结果为:

cuda:1

模型从cuda:0变成了cuda:1

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:2':'cpu'})
print(next(model.parameters()).device)

结果为:

cuda:0

模型还是cuda:0,并没有变成cpu。因为这个map_location的映射是不对的,原始的模型就是cuda:0,而映射是cuda:2cpu,是不对的。这种情况下,map_location返回None,也就是和不加map_location相同。

总结

到此这篇关于Python中torch.load()加载模型以及其map_location参数详解的文章就介绍到这了,更多相关torch.load()加载模型map_location参数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Python中csv文件的写入与读取方法例子

    Python中csv文件的写入与读取方法例子

    这篇文章主要给大家介绍了关于Python中csv文件的写入与读取方法的相关资料,csv是"Comma-Separated Values(逗号分割的值)"的首字母缩写,它其实和txt文件一样,都是纯文本文件,使用Python来读写csv文件是非常容易的,需要的朋友可以参考下
    2023-09-09
  • 基于pycharm导入模块显示不存在的解决方法

    基于pycharm导入模块显示不存在的解决方法

    今天小编就为大家分享一篇基于pycharm导入模块显示不存在的解决方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • 关于Python中的向量相加和numpy中的向量相加效率对比

    关于Python中的向量相加和numpy中的向量相加效率对比

    今天小编就为大家分享一篇关于Python中的向量相加和numpy中的向量相加效率对比,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • Python如何读取相对路径文件

    Python如何读取相对路径文件

    这篇文章主要介绍了Python如何读取相对路径文件问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-03-03
  • 关于jupyter lab安装及导入tensorflow找不到模块的问题

    关于jupyter lab安装及导入tensorflow找不到模块的问题

    这篇文章主要介绍了关于jupyter lab安装及导入tensorflow找不到模块的问题,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-03-03
  • Python CleverCSV轻松处理CSV文件指南

    Python CleverCSV轻松处理CSV文件指南

    这篇文章主要为大家介绍了Python CleverCSV轻松处理CSV文件全面指南,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2024-01-01
  • python进度条库tqdm的基本操作方法

    python进度条库tqdm的基本操作方法

    这篇文章主要介绍了python进度条库tqdm的基本操作方法,tqdm实时输出处理进度而且占用的CPU资源非常少,支持windows、Linux、mac等系统,支持循环处理、多进程、递归处理、还可以结合linux的命令来查看处理情况等优点,下面对其更多内容详细介绍,需要的朋友可以参考一下
    2022-03-03
  • Python光学仿真学习衍射算法初步理解

    Python光学仿真学习衍射算法初步理解

    这篇文章主要为大家介绍了Python光学仿真学习中对衍射算法的初步理解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步
    2021-10-10
  • Python协程的实现方式小结

    Python协程的实现方式小结

    协程是Python中强大的并发编程工具,允许开发者编写异步代码以提高程序的性能和效率,在本文中,我们将深入探讨Python中协程的实现方式,包括生成器、asyncio库和async/await关键字,我们还会提供详细的示例代码,帮助您理解和应用协程,需要的朋友可以参考下
    2023-11-11
  • Python中支持向量机SVM的使用方法详解

    Python中支持向量机SVM的使用方法详解

    这篇文章主要为大家详细介绍了Python中支持向量机SVM的使用方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-12-12

最新评论