pytorch载入预训练模型后,实现训练指定层

 更新时间:2020年01月06日 09:46:50   作者:慕白-  
今天小编就为大家分享一篇pytorch载入预训练模型后,实现训练指定层,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

参考:

pytorch官方文档

https://www.jb51.net/article/134943.htm

相关文章

  • Python中typing模块的具体使用

    Python中typing模块的具体使用

    本文主要介绍了Python中typing模块的具体使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-05-05
  • Python中的is和==比较两个对象的两种方法

    Python中的is和==比较两个对象的两种方法

    这篇文章主要介绍了Python中的is和==比较两个对象的两种方法的相关资料,希望通过本文能帮助到大家,需要的朋友可以参考下
    2017-09-09
  • Python数据清洗工具之Numpy的基本操作

    Python数据清洗工具之Numpy的基本操作

    Numpy的操作对象是一个ndarray,所以在使用这个库进行计算的时候需要将数据进行转化,这篇文章主要介绍了Python数据清洗工具之Numpy的基本操作,需要的朋友可以参考下
    2021-04-04
  • python wxpython 实现界面跳转功能

    python wxpython 实现界面跳转功能

    wxpython没提供界面跳转的方式,所以就需要借助threading模块,本文给大家分享python wxpython 实现界面跳转功能,感兴趣的朋友跟随小编一起看看吧
    2019-12-12
  • python加密解密库cryptography使用openSSL生成的密匙加密解密

    python加密解密库cryptography使用openSSL生成的密匙加密解密

    这篇文章主要介绍了python加密解密库cryptography使用openSSL生成的密匙加密解密,需要的朋友可以参考下
    2020-02-02
  • Python多进程池 multiprocessing Pool用法示例

    Python多进程池 multiprocessing Pool用法示例

    这篇文章主要介绍了Python多进程池 multiprocessing Pool用法,结合实例形式分析了多进程池 multiprocessing Pool相关概念、原理及简单使用技巧,需要的朋友可以参考下
    2018-09-09
  • python字符串替换re.sub()方法解析

    python字符串替换re.sub()方法解析

    这篇文章主要介绍了python字符串替换re.sub()方法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09
  • Python 任务自动化工具nox 的配置与 API详情

    Python 任务自动化工具nox 的配置与 API详情

    这篇文章主要介绍了Python 任务自动化工具nox 的配置与 API详情,Nox 会话是通过被@nox.session装饰的标准 Python 函数来配置的,具体详情下文相关介绍需要的小伙伴可以参考一下
    2022-07-07
  • python查看自己安装的所有库并导出的命令

    python查看自己安装的所有库并导出的命令

    这篇文章主要介绍了python查看自己安装的所有库并导出,主要包括查看安装的库通过命令查询,导出库安装文件执行命令,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-06-06
  • python 读取txt中每行数据,并且保存到excel中的实例

    python 读取txt中每行数据,并且保存到excel中的实例

    下面小编就为大家分享一篇python 读取txt中每行数据,并且保存到excel中的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04

最新评论