pytorch如何保存训练模型参数并实现继续训练

 更新时间:2023年09月11日 14:56:36   作者:回炉重造P  
这篇文章主要介绍了pytorch如何保存训练模型参数并实现继续训练问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

最近的想法是在推荐模型中考虑根据用户对推荐结果的后续选择,利用已训练的offline预训练模型参数来更新新的结果。

简单记录一下中途保存参数和后续使用不同数据训练的方法。

简单模型和训练数据

先准备一个简单模型,简单两层linear出个分类结果。

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)
        self.linear1 = nn.Linear(32, 10)
        self.relu = nn.ReLU()
    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear1(x)
        return x

准备训练用数据,这里直接随机两份,同时给出配套的十个类的分类label结果。

要注意的是 crossEntropy 交叉熵只认 long 以上的tensor,label记得转一下类型。

    rand1 = torch.rand((100, 64)).to(torch.float)
    label1 = np.random.randint(0, 10, size=100)
    label1 = torch.from_numpy(label1).to(torch.long)
    rand2 = torch.rand((100, 64)).to(torch.float)
    label2 = np.random.randint(0, 10, size=100)
    label2 = torch.from_numpy(label2).to(torch.long)

训练简单使用交叉熵,优化器Adam。

    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss = nn.CrossEntropyLoss()
    iteration = 100
    for i in range(iteration):
        output = model(rand1)
        my_loss = loss(output, label1)
        optimizer.zero_grad()
        my_loss.backward()
        optimizer.step()
        print("iteration:{} loss:{}".format(i, my_loss))

反正能跑起来:

在这里插入图片描述

保存与读取训练参数结果的方法

关键的保存方法,可以分为两种,一种是直接把模型整体保存:

torch.save(model, save_path)

两个参数,模型和保存目录。不过这种不常用,如果模型变化或者只需要其中一部分参数就不太灵活。

常用方法的是将需要的模型或优化器参数取出以字典形式存储,这样可以在使用时初始化相关模型,读入对应参数即可。

def save_model(save_path, iteration, optimizer, model):
    torch.save({'iteration': iteration,
                'optimizer_dict': optimizer.state_dict(),
                'model_dict': model.state_dict()},
                save_path)
    print("model save success")

分别存储训练循环次数,优化器设置和模型参数结果。

初始化模型,读取参数并设置:

def load_model(save_name, optimizer, model):
    model_data = torch.load(save_name)
    model.load_state_dict(model_data['model_dict'])
    optimizer.load_state_dict(model_data['optimizer_dict'])
    print("model load success")

初始化新模型:

    path = "net.dict"
    save_model(path, iteration, optimizer, model)
    print(model.state_dict()['linear.weight'])
    new_model = MyModel()
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01)
    load_model(path, new_optimizer, new_model)
    print(new_model.state_dict()['linear.weight'])

输出第一个linear层的参数看看,确实相同,参数成功读取上了。注意optimizer的初始化对应模型别写错了。

在这里插入图片描述

之后用新模型继续训练试试:

    for i in range(iteration):
        output = new_model(rand2)
        my_loss = loss(output, label2)
        new_optimizer.zero_grad()
        my_loss.backward()
        new_optimizer.step()
        print("iteration:{} loss:{}".format(i, my_loss))

能成功训练。

在这里插入图片描述

变化学习率的保存

上面的demo只用了固定的学习率来做实验。

如果使用了 scheduler 来变化步长,只要保存 scheduler state_dict ,之后对新初始化的 scheduler 设置对应的当前循环次数即可。

# 存储时
'scheduler': scheduler.state_dict()
# 读取时
scheduler.load_state_dict(checkpoint['lr_schedule'])

scheduler的使用可以看看我之前整理的文章:利用scheduler实现learning-rate学习率动态变化

总结

这次主要是整理了一下pytorch模型参数的整体保存方法,来实现新数据的后续训练或直接作为offline预训练模型来使用。

不过后续数据分布不同的话感觉效果会很差啊…

也不知道能不能用什么算法修改下权重来贴合新的数据,找点多次训练优化论文看看好了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 使用python绘制常用的图表

    使用python绘制常用的图表

    本文给大家介绍的是如何使用Python根据Excel表格数据绘制不同的图表的方法,非常的详细,有相同需求的小伙伴可以参考下
    2016-08-08
  • 简单聊聊PyTorch里面的torch.nn.Parameter()

    简单聊聊PyTorch里面的torch.nn.Parameter()

    torch.nn.parameter是一个被用作神经网络模块参数的tensor,这是一种tensor的子类,下面这篇文章主要给大家介绍了关于PyTorch里面的torch.nn.Parameter()的相关资料,需要的朋友可以参考下
    2022-02-02
  • python xlwt如何设置单元格的自定义背景颜色

    python xlwt如何设置单元格的自定义背景颜色

    这篇文章主要介绍了python xlwt如何设置单元格的自定义背景颜色,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • Python批量发送post请求的实现代码

    Python批量发送post请求的实现代码

    昨天学了一天的Python(我的生产语言是java,也可以写一些shell脚本,算有一点点基础),今天有一个应用场景,就正好练手了
    2018-05-05
  • python使用smtplib模块发送邮件

    python使用smtplib模块发送邮件

    这篇文章主要为大家详细介绍了python使用smtplib模块发送邮件,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-12-12
  • Python tkinter控件样式详解

    Python tkinter控件样式详解

    tkinter对控件的诸多属性提供了可定制的功能,下面以最常用的按钮作为示例,集中展示其样式特点,感兴趣的小伙伴可以跟随小编一起学习一下
    2023-09-09
  • 跟老齐学Python之集合(set)

    跟老齐学Python之集合(set)

    本文主要内容是要向各位介绍一种新的数据类型:集合(set).彻底晕倒了,到底python有多少个数据类型呢?又多出来了一个.
    2014-09-09
  • 如何提取Playwright录制文件中的元素定位信息

    如何提取Playwright录制文件中的元素定位信息

    最近在学习Playwright自动化测试,本文主要介绍了如何提取Playwright录制文件中的元素定位信息,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-12-12
  • Python实现随机取一个矩阵数组的某几行

    Python实现随机取一个矩阵数组的某几行

    今天小编就为大家分享一篇Python实现随机取一个矩阵数组的某几行,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-11-11
  • python+opencv实现阈值分割

    python+opencv实现阈值分割

    这篇文章主要为大家详细介绍了python+opencv实现阈值分割的相关代码,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-12-12

最新评论