pytorch GPU和CPU模型相互加载方式

 更新时间:2024年09月09日 10:47:51   作者:Overboom  
在PyTorch中,保存和加载模型有两种主要方式:直接保存整个模型结构加权重,或者只保存模型的参数,直接保存整个模型的方法简单,但不够灵活,且可能存在模型结构不一致的风险,推荐的做法是只保存模型参数,这种方法需要在加载前定义与原模型结构相同的模型

1 pytorch保存模型的两种方式

1.1 直接保存模型并读取

# 创建你的模型实例对象: model
model = net()
## 保存模型
torch.save(model, 'model_name.pth')

## 读取模型
model = torch.load('model_name.pth')

1.2 只保存模型中的参数并读取

## 保存模型
torch.save({'model': model.state_dict()}, 'model_name.pth')

## 读取模型
model = net()
state_dict = torch.load('model_name.pth')
model.load_state_dict(state_dict['model'])
  • 第一种方法可以直接保存模型,加载模型的时候直接把读取的模型给一个参数就行。
  • 第二种方法则只是保存参数,在读取模型参数前要先定义一个模型(模型必须与原模型相同的构造),然后对这个模型导入参数。虽然麻烦,但是可以同时保存多个模型的参数,而第一种方法则不能,而且第一种方法有时不能保证模型的相同性(你读取的模型并不是你想要的)。

如何保存模型决定了如何读取模型,一般来选择第二种来保存和读取。

2 GPU / CPU模型相互加载

2.1 单个CPU和单个GPU模型加载

pytorch 允许把在GPU上训练的模型加载到CPU上,也允许把在CPU上训练的模型加载到GPU上。

加载模型参数的时候,在GPU和CPU训练的模型是不一样的,这两种模型是不能混为一谈的,下面分情况进行操作说明。

情况一:CPU -> CPU, GPU -> GPU

  • GPU训练的模型,在GPU上使用;
  • CPU训练的模型,在CPU上使用,

这种情况下我们都只用直接用下面的语句即可:

torch.load('model_dict.pth')

情况二:GPU -> CPG/GPU

GPU训练的模型,不知道放在CPU还是GPU运行,两种情况都要考虑

import torch
from torchvision import models

# 加载预训练的GPU模型权重文件
weights_path = 'model_gpu.pth'

# 定义一个与原模型结构相同的新模型
model = models.resnet50()

# 检查是否有可用的CUDA设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将权重映射到相应的设备内存并加载到模型中
weights = torch.load(weights_path, map_location=device)
model.load_state_dict(weights)

# 设置为评估模式
model.eval()

print("Model is successfully loaded and can be used on a", device.type, "!")

情况三:CPU -> CPG/GPU

模型是在CPU上训练的,但不确定要在CPU还是GPU上运行时,两种情况都要考虑

import torch
from torchvision import models

# 加载预训练的CPU模型权重文件
weights_path = 'model_cpu.pth'

# 定义一个与原模型结构相同的新模型
model = models.resnet50()

# 检查是否有可用的CUDA设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将权重映射到相应的设备内存并加载到模型中
if device.type == 'cuda':
    model.to(device)
    weights = torch.load(weights_path, map_location=device)
else:
    weights = torch.load(weights_path, map_location='cpu')

model.load_state_dict(weights)

# 设置为评估模式
model.eval()

print("Model is successfully loaded and can be used on a", device.type, "!")

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 基于Python制作一个端午节相关的小游戏

    基于Python制作一个端午节相关的小游戏

    端午节快乐,今天我将为大家带来一篇有关端午节的编程文章,希望能够为大家献上一份小小的惊喜,我们将会使用Python来实现一个与端午粽子相关的小应用程序,在本文中,我将会介绍如何用Python代码制做一个“粽子拆解器”,感兴趣的小伙伴欢迎阅读
    2023-06-06
  • 详解python数据结构之队列Queue

    详解python数据结构之队列Queue

    这篇文章主要介绍了python数据结构之队列Queue,文中有非常详细的代码示例,对正在学习python的小伙伴们有很好的帮助,需要的朋友可以参考下
    2021-05-05
  • 基于Python编写词云软件并显示分词结果

    基于Python编写词云软件并显示分词结果

    这篇文章主要为大家详细介绍了如何基于Python编写一个简单的词云制作软件并显示分词结果,文中的示例代码讲解详细,具有一定的学习价值,感兴趣的小伙伴可以了解一下
    2023-10-10
  • Python基础之logging模块知识总结

    Python基础之logging模块知识总结

    用Python写代码的时候,在想看的地方写个print xx 就能在控制台上显示打印信息,这样子就能知道它是什么了,但是当我需要看大量的地方或者在一个文件中查看的时候,这时候print就不大方便了,所以Python引入了logging模块来记录我想要的信息,需要的朋友可以参考下
    2021-05-05
  • Python制作数据预测集成工具(值得收藏)

    Python制作数据预测集成工具(值得收藏)

    这篇文章主要介绍了Python如何制作数据预测集成工具,帮助大家进行大数据预测,感兴趣的朋友可以了解下
    2020-08-08
  • python超简单解决约瑟夫环问题

    python超简单解决约瑟夫环问题

    这篇文章主要介绍了python超简单解决约瑟夫环问题的方法,详细描述的约瑟夫环问题的描述与Python解决方法,需要的朋友可以参考下
    2015-05-05
  • 基于并发服务器几种实现方法(总结)

    基于并发服务器几种实现方法(总结)

    下面小编就为大家分享一篇基于并发服务器几种实现方法(总结),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2017-12-12
  • python实现层次聚类的方法

    python实现层次聚类的方法

    层次聚类就是一层一层的进行聚类,可以由上向下把大的类别(cluster)分割,叫作分裂法,这篇文章主要介绍了python实现层次聚类的方法,需要的朋友可以参考下
    2021-11-11
  • Python自然语言处理之切分算法详解

    Python自然语言处理之切分算法详解

    这篇文章主要介绍了Python自然语言处理之切分算法详解,文中有非常详细的代码示例,对正在学习python的小伙伴们有非常好的帮助,需要的朋友可以参考下
    2021-04-04
  • Python 使用 MySQL 数据库进行事务处理完整示例

    Python 使用 MySQL 数据库进行事务处理完整示例

    本文介绍了Python中使用MySQL进行事务处理的基本概念和步骤,包括事务的核心概念(ACID原则)、事务处理代码示例、关键操作解释以及拓展场景,感兴趣的朋友跟随小编一起看看吧
    2026-01-01

最新评论