Python中torch.load()加载模型以及其map_location参数详解

 更新时间:2022年09月23日 09:55:41   作者:eecspan  
torch.load()作用用来加载torch.save()保存的模型文件,下面这篇文章主要给大家介绍了关于Python中torch.load()加载模型以及其map_location参数的相关资料,需要的朋友可以参考下

参考

TORCH.LOAD

torch.load()

函数格式为:torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args),一般我们使用的时候,基本只使用前两个参数。

模型的保存

模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过model.load_state_dict(dict)将模型的参数更新。

另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。

具体可参考:PyTorch模型的保存与加载

模型加载中的map_location参数

具体来说,map_location参数是用于重定向,比如此前模型的参数是在cpu中的,我们希望将其加载到cuda:0中。或者我们有多张卡,那么我们就可以将卡1中训练好的模型加载到卡2中,这在数据并行的分布式深度学习中可能会用到。

首先定义一个AlexNet,并使用cuda:0将其训练了一个猫狗分类,之后把模型存储起来。

map_location=None

我们先把state_dict加载进来。

model_path = "./cuda_model.pth"
model = torch.load(model_path)
print(next(model.parameters()).device)

结果为:

cuda:0

因为保存的时候就是模型就是cuda:0的,所以加载进来也是。

map_location=torch.device()

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location=torch.device('cpu'))
print(next(model.parameters()).device)

结果为:

cpu

模型从cuda:0变成了cpu

map_location={xx:xx}

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:0':'cuda:1'})
print(next(model.parameters()).device)

结果为:

cuda:1

模型从cuda:0变成了cuda:1

model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:2':'cpu'})
print(next(model.parameters()).device)

结果为:

cuda:0

模型还是cuda:0,并没有变成cpu。因为这个map_location的映射是不对的,原始的模型就是cuda:0,而映射是cuda:2cpu,是不对的。这种情况下,map_location返回None,也就是和不加map_location相同。

总结

到此这篇关于Python中torch.load()加载模型以及其map_location参数详解的文章就介绍到这了,更多相关torch.load()加载模型map_location参数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python决策树预测学生成绩等级实现详情

    python决策树预测学生成绩等级实现详情

    这篇文章主要为介绍了python决策树预测学生成绩等级,使用决策树完成学生成绩等级预测,可选取部分或全部特征,分析参数对结果的影响,并进行调参优化,决策树可视化进行调参优化分析
    2022-04-04
  • python列表逆序排列的4种方法

    python列表逆序排列的4种方法

    python中的列表是可以直接进行逆序排列的,本文主要介绍了python列表逆序排列的方法,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2023-05-05
  • Python三维绘图之Matplotlib库的使用方法

    Python三维绘图之Matplotlib库的使用方法

    这篇文章主要给大家介绍了关于Python三维绘图之Matplotlib库的使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-09-09
  • Python实现博客快速备份的脚本分享

    Python实现博客快速备份的脚本分享

    本文针对博客园实现了一个自动备份脚本,可以快速将自己的文章备份成Markdown格式的独立文件,备份后的md文件可以直接放入到hexo博客中,感兴趣的可以了解一下
    2022-09-09
  • 浅谈python在提示符下使用open打开文件失败的原因及解决方法

    浅谈python在提示符下使用open打开文件失败的原因及解决方法

    今天小编就为大家分享一篇浅谈python在提示符下使用open打开文件失败的原因及解决方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • 解决csv.writer写入文件有多余的空行问题

    解决csv.writer写入文件有多余的空行问题

    今天小编就为大家分享一篇解决csv.writer写入文件有多余的空行问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • Python流程控制 if else实现解析

    Python流程控制 if else实现解析

    这篇文章主要介绍了Python 流程控制 if else实现解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • 浅谈DataFrame和SparkSql取值误区

    浅谈DataFrame和SparkSql取值误区

    今天小编就为大家分享一篇浅谈DataFrame和SparkSql取值误区,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • pycharm设置默认的UTF-8编码模式的方法详解

    pycharm设置默认的UTF-8编码模式的方法详解

    这篇文章主要介绍了pycharm设置默认的UTF-8编码模式,本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-06-06
  • pandas 时间偏移的实现

    pandas 时间偏移的实现

    时间偏移就是在指定时间往前推或者往后推一段时间,即加减一段时间之后的时间,本文使用Python实现,感兴趣的可以了解一下
    2021-08-08

最新评论