pytorch GPU和CPU模型相互加载方式

 更新时间:2024年09月09日 10:47:51   作者:Overboom  
在PyTorch中,保存和加载模型有两种主要方式:直接保存整个模型结构加权重,或者只保存模型的参数,直接保存整个模型的方法简单,但不够灵活,且可能存在模型结构不一致的风险,推荐的做法是只保存模型参数,这种方法需要在加载前定义与原模型结构相同的模型

1 pytorch保存模型的两种方式

1.1 直接保存模型并读取

# 创建你的模型实例对象: model
model = net()
## 保存模型
torch.save(model, 'model_name.pth')

## 读取模型
model = torch.load('model_name.pth')

1.2 只保存模型中的参数并读取

## 保存模型
torch.save({'model': model.state_dict()}, 'model_name.pth')

## 读取模型
model = net()
state_dict = torch.load('model_name.pth')
model.load_state_dict(state_dict['model'])
  • 第一种方法可以直接保存模型,加载模型的时候直接把读取的模型给一个参数就行。
  • 第二种方法则只是保存参数,在读取模型参数前要先定义一个模型(模型必须与原模型相同的构造),然后对这个模型导入参数。虽然麻烦,但是可以同时保存多个模型的参数,而第一种方法则不能,而且第一种方法有时不能保证模型的相同性(你读取的模型并不是你想要的)。

如何保存模型决定了如何读取模型,一般来选择第二种来保存和读取。

2 GPU / CPU模型相互加载

2.1 单个CPU和单个GPU模型加载

pytorch 允许把在GPU上训练的模型加载到CPU上,也允许把在CPU上训练的模型加载到GPU上。

加载模型参数的时候,在GPU和CPU训练的模型是不一样的,这两种模型是不能混为一谈的,下面分情况进行操作说明。

情况一:CPU -> CPU, GPU -> GPU

  • GPU训练的模型,在GPU上使用;
  • CPU训练的模型,在CPU上使用,

这种情况下我们都只用直接用下面的语句即可:

torch.load('model_dict.pth')

情况二:GPU -> CPG/GPU

GPU训练的模型,不知道放在CPU还是GPU运行,两种情况都要考虑

import torch
from torchvision import models

# 加载预训练的GPU模型权重文件
weights_path = 'model_gpu.pth'

# 定义一个与原模型结构相同的新模型
model = models.resnet50()

# 检查是否有可用的CUDA设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将权重映射到相应的设备内存并加载到模型中
weights = torch.load(weights_path, map_location=device)
model.load_state_dict(weights)

# 设置为评估模式
model.eval()

print("Model is successfully loaded and can be used on a", device.type, "!")

情况三:CPU -> CPG/GPU

模型是在CPU上训练的,但不确定要在CPU还是GPU上运行时,两种情况都要考虑

import torch
from torchvision import models

# 加载预训练的CPU模型权重文件
weights_path = 'model_cpu.pth'

# 定义一个与原模型结构相同的新模型
model = models.resnet50()

# 检查是否有可用的CUDA设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将权重映射到相应的设备内存并加载到模型中
if device.type == 'cuda':
    model.to(device)
    weights = torch.load(weights_path, map_location=device)
else:
    weights = torch.load(weights_path, map_location='cpu')

model.load_state_dict(weights)

# 设置为评估模式
model.eval()

print("Model is successfully loaded and can be used on a", device.type, "!")

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 利用pyinstaller将py文件打包为exe的方法

    利用pyinstaller将py文件打包为exe的方法

    本篇文章主要介绍了利用pyinstaller将py文件打包为exe的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-05-05
  • python缺失值的解决方法总结

    python缺失值的解决方法总结

    在本篇文章里小编给大家整理的是一篇关于python缺失值的解决方法总结,有需要的朋友们可以跟着学习下。
    2021-06-06
  • 从零学Python之hello world

    从零学Python之hello world

    从今天开始讲陆续发布一系列python基础教程,让新手更快更好的入门。
    2014-05-05
  • 详解用python -m http.server搭一个简易的本地局域网

    详解用python -m http.server搭一个简易的本地局域网

    这篇文章主要介绍了详解用python -m http.server搭一个简易的本地局域网,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-09-09
  • Python FTP操作类代码分享

    Python FTP操作类代码分享

    这篇文章主要介绍了Python FTP操作类,实现自动下载、自动上传,并可以递归目录操作,需要的朋友可以参考下
    2014-05-05
  • python装饰器实例大详解

    python装饰器实例大详解

    这篇文章主要介绍了python装饰器实例大详解,非常不错,具有参考借鉴价值,需要的朋友可以参考下
    2017-10-10
  • 用selenium解决滑块验证码的实现步骤

    用selenium解决滑块验证码的实现步骤

    验证码作为一种自然人的机器人的判别工具,被广泛的用于各种防止程序做自动化的场景中,下面这篇文章主要给大家介绍了关于用selenium解决滑块验证码的实现步骤,需要的朋友可以参考下
    2023-02-02
  • Python利用卡方Chi特征检验实现提取关键文本特征

    Python利用卡方Chi特征检验实现提取关键文本特征

    卡方检验最基本的思想就是通过观察实际值与理论值的偏差来确定理论的正确与否。本文将利用卡方Chi特征检验实现提取关键文本特征功能,感兴趣的可以了解一下
    2022-12-12
  • Pygame实现简易版趣味小游戏之反弹球

    Pygame实现简易版趣味小游戏之反弹球

    这篇文章主要为大家详细介绍了python实现简易版趣味反弹球游戏,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-03-03
  • python linecache读取行更新的实现

    python linecache读取行更新的实现

    本文主要介绍了python linecache读取行更新的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-03-03

最新评论