深入理解Pytorch微调torchvision模型

 更新时间:2021年11月10日 16:58:22   作者:柚子味的羊  
PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序。它主要由Facebookd的人工智能小组开发,不仅能够 实现强大的GPU加速,同时还支持动态神经网络,这一点是现在很多主流框架如TensorFlow都不支持的

一、简介

在本小节,深入探讨如何对torchvision进行微调和特征提取。所有模型都已经预先在1000类的magenet数据集上训练完成。 本节将深入介绍如何使用几个现代的CNN架构,并将直观展示如何微调任意的PyTorch模型。
本节将执行两种类型的迁移学习:

  • 微调:从预训练模型开始,更新我们新任务的所有模型参数,实质上是重新训练整个模型。
  • 特征提取:从预训练模型开始,仅更新从中导出预测的最终图层权重。它被称为特征提取,因为我们使用预训练的CNN作为固定 的特征提取器,并且仅改变输出层。

通常这两种迁移学习方法都会遵循一下步骤:

  • 初始化预训练模型
  • 重组最后一层,使其具有与新数据集类别数相同的输出数
  • 为优化算法定义想要的训练期间更新的参数
  • 运行训练步骤

二、导入相关包

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision 
from torchvision import datasets,models,transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("Pytorch version:",torch.__version__)
print("torchvision version:",torchvision.__version__)

运行结果

在这里插入图片描述

三、数据输入

数据集——>我在这里

链接:https://pan.baidu.com/s/1G3yRfKTQf9sIq1iCSoymWQ
提取码:1234

#%%输入
data_dir="D:\Python\Pytorch\data\hymenoptera_data"
# 从[resnet,alexnet,vgg,squeezenet,desenet,inception]
model_name='squeezenet'
# 数据集中类别数量
num_classes=2
# 训练的批量大小
batch_size=8
# 训练epoch数
num_epochs=15
# 用于特征提取的标志。为FALSE,微调整个模型,为TRUE只更新图层参数
feature_extract=True

四、辅助函数

1、模型训练和验证

  • train_model函数处理给定模型的训练和验证。作为输入,它需要PyTorch模型、数据加载器字典、损失函数、优化器、用于训练和验 证epoch数,以及当模型是初始模型时的布尔标志。
  • is_inception标志用于容纳 Inception v3 模型,因为该体系结构使用辅助输出, 并且整体模型损失涉及辅助输出和最终输出,如此处所述。 这个函数训练指定数量的epoch,并且在每个epoch之后运行完整的验证步骤。它还跟踪最佳性能的模型(从验证准确率方面),并在训练 结束时返回性能最好的模型。在每个epoch之后,打印训练和验证正确率。
#%%模型训练和验证
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False):
    since=time.time()
    val_acc_history=[]
    best_model_wts=copy.deepcopy(model.state_dict())
    best_acc=0.0
    for epoch in range(num_epochs):
        print('Epoch{}/{}'.format(epoch, num_epochs-1))
        print('-'*10)
        # 每个epoch都有一个训练和验证阶段
        for phase in['train','val']:
            if phase=='train':
                model.train()
            else:
                model.eval()
                
            running_loss=0.0
            running_corrects=0
            # 迭代数据
            for inputs,labels in dataloaders[phase]:
                inputs=inputs.to(device)
                labels=labels.to(device)
                # 梯度置零
                optimizer.zero_grad()
                # 向前传播
                with torch.set_grad_enabled(phase=='train'):
                    # 获取模型输出并计算损失,开始的特殊情况在训练中他有一个辅助输出
                    # 在训练模式下,通过将最终输出和辅助输出相加来计算损耗,在测试中值考虑最终输出
                    if is_inception and phase=='train':
                        outputs,aux_outputs=model(inputs)
                        loss1=criterion(outputs,labels)
                        loss2=criterion(aux_outputs,labels)
                        loss=loss1+0.4*loss2
                    else:
                        outputs=model(inputs)
                        loss=criterion(outputs,labels)
                        
                    _,preds=torch.max(outputs,1)
                    
                    if phase=='train':
                        loss.backward()
                        optimizer.step()
                        
                # 添加
                running_loss+=loss.item()*inputs.size(0)
                running_corrects+=torch.sum(preds==labels.data)
                
            epoch_loss=running_loss/len(dataloaders[phase].dataset)
            epoch_acc=running_corrects.double()/len(dataloaders[phase].dataset)
            
            print('{}loss : {:.4f} acc:{:.4f}'.format(phase, epoch_loss,epoch_acc))
            
            if phase=='train' and epoch_acc>best_acc:
                best_acc=epoch_acc
                best_model_wts=copy.deepcopy(model.state_dict())
            if phase=='val':
                val_acc_history.append(epoch_acc)
            
        print()

    time_elapsed=time.time()-since
    print('training complete in {:.0f}s'.format(time_elapsed//60, time_elapsed%60))
    print('best val acc:{:.4f}'.format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model,val_acc_history

2、设置模型参数的'.requires_grad属性'

当我们进行特征提取时,此辅助函数将模型中参数的 .requires_grad 属性设置为False。
默认情况下,当我们加载一个预训练模型时,所有参数都是 .requires_grad = True,如果我们从头开始训练或微调,这种设置就没问题。
但是,如果我们要运行特征提取并且只想为新初始化的层计算梯度,那么我们希望所有其他参数不需要梯度变化。

#%%设置模型参数的.require——grad属性
def set_parameter_requires_grad(model,feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.require_grad=False

靓仔今天先去跑步了,再不跑来不及了,先更这么多,后续明天继续~(感谢有人没有催更!感谢监督!希望继续监督!)

以上就是深入理解Pytorch微调torchvision模型的详细内容,更多关于Pytorch torchvision模型的资料请关注脚本之家其它相关文章!

相关文章

  • python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法

    python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法

    这篇文章主要介绍了python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法,具有很好的参考价值,希望对大家有所帮助。
    2021-06-06
  • 详解如何使用Pandas处理时间序列数据

    详解如何使用Pandas处理时间序列数据

    时间序列数据在数据分析建模中很常见,例如天气预报,空气状态监测,股票交易等金融场景,本文给大家详细介绍了如何使用Pandas处理时间序列数据,文中通过代码示例讲解的非常详细,需要的朋友可以参考下
    2024-01-01
  • Python序列化与反序列化相关知识总结

    Python序列化与反序列化相关知识总结

    今天给大家带来关于python的相关知识,文章围绕着Python序列化与反序列展开,文中有非常详细的介绍,需要的朋友可以参考下
    2021-06-06
  • Python pandas RFM模型应用实例详解

    Python pandas RFM模型应用实例详解

    这篇文章主要介绍了Python pandas RFM模型应用,结合实例形式详细分析了pandas RFM模型的概念、原理、应用及相关操作注意事项,需要的朋友可以参考下
    2019-11-11
  • 使用matplotlib绘制并排柱状图的实战案例

    使用matplotlib绘制并排柱状图的实战案例

    堆积柱状图有堆积柱状图的好处,比如说我们可以很方便地看到多分类总和的趋势,下面这篇文章主要给大家介绍了关于使用matplotlib绘制并排柱状图的相关资料,需要的朋友可以参考下
    2022-07-07
  • python实现web方式logview的方法

    python实现web方式logview的方法

    这篇文章主要介绍了python实现web方式logview的方法,涉及Python基于web模块操作Linux命令的技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-08-08
  • python安装dlib库报错问题及解决方法

    python安装dlib库报错问题及解决方法

    这篇文章主要介绍了python安装dlib库报错问题及解决方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-03-03
  • Flask与数据库的交互插件Flask-Sqlalchemy的使用

    Flask与数据库的交互插件Flask-Sqlalchemy的使用

    在构建Web应用时,与数据库的交互是必不可少的部分,本文主要介绍了Flask与数据库的交互插件Flask-Sqlalchemy的使用,具有一定的参考价值,感兴趣的可以了解一下
    2024-03-03
  • 简单说明Python中的装饰器的用法

    简单说明Python中的装饰器的用法

    这篇文章主要简单说明了Python中的装饰器的用法,装饰器在Python的进阶学习中非常重要,示例代码基于Python2.x,需要的朋友可以参考下
    2015-04-04
  • python实现登录与注册系统

    python实现登录与注册系统

    这篇文章主要为大家详细介绍了python实现登录与注册系统,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-11-11

最新评论