pytorch模型保存到本地后,如何实现继续训练

 更新时间:2024年09月09日 15:05:08   作者:hejp_123  
在PyTorch中,保存和加载模型对于实现模型训练的中断和恢复非常有用,保存模型主要有两种方式:一是保存整个模型包括结构与参数;二是仅保存模型的state_dict,加载模型时,若保存了整个模型则直接加载,若仅保存了state_dict,则需先实例化模型结构后加载

在 PyTorch 中,你可以通过以下步骤保存和加载模型,然后继续训练:

1.保存模型

通常有两种方式来保存模型:

保存整个模型(包括网络结构、权重等):

torch.save(model, 'model.pth')

只保存模型的state_dict(只包含权重参数),推荐使用这种方式,因为这样可以节省存储空间,并且在加载时更灵活:

torch.save(model.state_dict(), 'model_weights.pth')

2.加载模型

对应地,也有两种方式来加载模型:

如果你之前保存了整个模型,可以直接通过下面的方式加载:

model = torch.load('model.pth')

如果你之前只保存了state_dict,需要先实例化一个与原模型结构相同的模型,然后通过load_state_dict()方法加载权重:

# 实例化一个与原模型结构相同的模型
model = YourModelClass()

# 加载保存的state_dict
model.load_state_dict(torch.load('model_weights.pth'))

# 确保将模型转移到正确的设备上(例如GPU或CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

3.继续训练

加载完模型后,就可以继续训练了。

确保你已经定义了损失函数和优化器,并且它们的状态也要正确加载(如果你之前保存了它们的话)。然后,按照正常的训练流程进行即可

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 如果之前保存了优化器状态,也可以加载
optimizer.load_state_dict(torch.load('optimizer.pth'))

# 开始训练
for epoch in range(num_epochs):
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

这样,你就可以从上次保存的地方继续训练模型了。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 基于Python编写一个词云制作程序

    基于Python编写一个词云制作程序

    这篇文章主要为大家详细介绍了如何基于Python编写一个简单的词云制作程序,文中的示例代码讲解详细,具有一定的学习价值,感兴趣的小伙伴可以了解一下
    2023-10-10
  • python实现日历效果

    python实现日历效果

    这篇文章主要为大家详细介绍了python实现日历效果,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-08-08
  • 向量化操作改进数据分析工作流的Pandas Numpy示例分析

    向量化操作改进数据分析工作流的Pandas Numpy示例分析

    这篇文章主要介绍了向量化操作改进数据分析工作流的Pandas Numpy示例分析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-10-10
  • Python Requests.post()请求失败时的retry设置方式

    Python Requests.post()请求失败时的retry设置方式

    这篇文章主要介绍了Python Requests.post()请求失败时的retry设置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-08-08
  • python中Pytest常用的插件

    python中Pytest常用的插件

    这篇文章主要介绍了python中Pytest常用的插件,Pytest是Python的一种单元测试框架,与unittest相比,使用起来更简洁、效率更高,也是目前大部分使用python编写测试用例的小伙伴们的第一选择了
    2022-06-06
  • Python单元测试简单示例

    Python单元测试简单示例

    这篇文章主要介绍了Python单元测试,结合实例形式分析了Python单元测试的简单定义、使用方法及相关操作注意事项,需要的朋友可以参考下
    2018-07-07
  • 如何使用django的MTV开发模式返回一个网页

    如何使用django的MTV开发模式返回一个网页

    这篇文章主要介绍了如何使用django的MTV开发模式返回一个网页,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • Python eval函数原理及用法解析

    Python eval函数原理及用法解析

    这篇文章主要介绍了Python eval函数原理及用法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11
  • Python入门教程1. 基本运算【四则运算、变量、math模块等】

    Python入门教程1. 基本运算【四则运算、变量、math模块等】

    这篇文章主要介绍了Python教程的基本运算,包括四则运算、变量的使用与类型检测、math模块等,并附带了相关说明,代码备有较为详尽的说明,便于理解,需要的朋友可以参考下
    2018-10-10
  • 基于Python+OpenCV实现自动扫雷功能

    基于Python+OpenCV实现自动扫雷功能

    相信许多人很早就知道有扫雷这么一款经典的游(显卡测试)戏(软件),扫雷作为一款在Windows9x时代就已经诞生的经典游戏,从过去到现在依然都有着它独特的魅力,所以本文小编给大家介绍了如何使用Python+OpenCV实现自动扫雷效果,感兴趣的朋友可以参考下
    2023-12-12

最新评论