pytorch模型保存与加载中的一些问题实战记录

 更新时间:2022年10月28日 12:37:20   作者:colourmind  
一般来说,保存模型是把参数全部用model.cpu().state_dict(),然后加载模型时一般用model.load_state_dict(torch.load(model_path)),下面这篇文章主要给大家介绍了关于pytorch模型保存与加载中的一些问题实战记录,需要的朋友可以参考下

前言

最近使用pytorch训练模型,保存模型后再次加载使用出现了一些问题。记录一下解决方案!

一、torch中模型保存和加载的方式

1、模型参数和模型结构保存和加载

torch.save(model,path)
torch.load(path)

2、只保存模型的参数和加载——这种方式比较安全,但是比较稍微麻烦一点点

torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)

二、torch中模型保存和加载出现的问题

1、单卡模型下保存模型结构和参数后加载出现的问题

模型保存的时候会把模型结构定义文件路径记录下来,加载的时候就会根据路径解析它然后装载参数;当把模型定义文件路径修改以后,使用torch.load(path)就会报错。

把model文件夹修改为models后,再加载就会报错。

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)

这种保存完整模型结构和参数的方式,一定不要改动模型定义文件路径

2、多卡机器单卡训练模型保存后在单卡机器上加载会报错

在多卡机器上有多张显卡0号开始,现在模型在n>=1上的显卡训练保存后,拷贝在单卡机器上加载

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)

会出现cuda device不匹配的问题——你保存的模代码段 小部件型是使用的cuda1,那么采用torch.load()打开的时候,会默认的去寻找cuda1,然后把模型加载到该设备上。这个时候可以直接使用map_location来解决,把模型加载到CPU上即可。

load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))

3、多卡训练模型保存模型结构和参数后加载出现的问题

当用多GPU同时训练模型之后,不管是采用模型结构和参数一起保存还是单独保存模型参数,然后在单卡下加载都会出现问题

a、模型结构和参数一起保然后在加载

torch.distributed.init_process_group(backend='nccl')

模型训练的时候采用上述多进程的方式,所以你在加载的时候也要声明,不然就会报错。

b、单独保存模型参数

model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)

同样会出现问题,不过这里出现的问题是参数字典的key和模型定义的key不一样

原因是多GPU训练下,使用分布式训练的时候会给模型进行一个包装,代码如下:

model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)

包装前的模型结构:

包装后的模型

在外层多了DistributedDataParallel以及module,所以才会导致在单卡环境下加载模型权重的时候出现权重的keys不一致。

三、正确的保存模型和加载的方法

    if gpu_count > 1:
        torch.save(model.module.state_dict(),save_path)
    else:
        torch.save(model.state_dict(),save_path)
    model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
    state_dict = torch.load(save_path)
    model.load_state_dict(state_dict)

这样就是比较好的范式,加载不会出错。

总结

到此这篇关于pytorch模型保存与加载中的一些问题的文章就介绍到这了,更多相关pytorch模型保存与加载内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • python 模拟创建seafile 目录操作示例

    python 模拟创建seafile 目录操作示例

    这篇文章主要介绍了python 模拟创建seafile 目录操作,结合实例形式详细分析了Python模拟创建seafile 目录相关操作技巧,需要的朋友可以参考下
    2019-09-09
  • Django 事务回滚的具体实现

    Django 事务回滚的具体实现

    本文主要介绍了Django 事务回滚的具体实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • Python实现TXT数据转三维矩阵

    Python实现TXT数据转三维矩阵

    在数据处理和分析中,将文本文件中的数据转换为三维矩阵是一个常见的任务,本文将详细介绍如何使用Python实现这一任务,感兴趣的小伙伴可以了解下
    2024-01-01
  • Python脚本处理空格的方法

    Python脚本处理空格的方法

    这篇文章主要介绍了Python脚本处理空格的方法,解决方案非常简单,但是好多朋友都不知道,下面小编把解决方案分享到脚本之家平台,供大家参考
    2016-08-08
  • Python设计模式之组合模式原理与用法实例分析

    Python设计模式之组合模式原理与用法实例分析

    这篇文章主要介绍了Python设计模式之组合模式,结合具体实例形式分析了Python组合模式的相关概念、原理、定义及使用方法,需要的朋友可以参考下
    2019-01-01
  • 8种用Python实现线性回归的方法对比详解

    8种用Python实现线性回归的方法对比详解

    这篇文章主要介绍了8种用Python实现线性回归的方法对比详解,说到如何用Python执行线性回归,大部分人会立刻想到用sklearn的linear_model,但事实是,Python至少有8种执行线性回归的方法,sklearn并不是最高效的,需要的朋友可以参考下
    2019-07-07
  • tensorflow之并行读入数据详解

    tensorflow之并行读入数据详解

    今天小编就为大家分享一篇tensorflow之并行读入数据详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • python pyinstaller打包exe报错的解决方法

    python pyinstaller打包exe报错的解决方法

    这篇文章主要给大家介绍了关于python pyinstaller打包exe报错的解决方法,文中通过示例代码介绍的非常详细,对大家的学习或者使用python具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-11-11
  • pygame实现俄罗斯方块游戏

    pygame实现俄罗斯方块游戏

    这篇文章主要为大家详细介绍了pygame实现俄罗斯方块游戏,代码注释详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-06-06
  • Win10下Python环境搭建与配置教程

    Win10下Python环境搭建与配置教程

    这篇文章主要为大家详细介绍了Windows10下Python环境搭建与配置,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2016-11-11

最新评论