深入理解Pytorch微调torchvision模型

 更新时间:2021年11月10日 16:58:22   作者:柚子味的羊  
PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序。它主要由Facebookd的人工智能小组开发,不仅能够 实现强大的GPU加速,同时还支持动态神经网络,这一点是现在很多主流框架如TensorFlow都不支持的

一、简介

在本小节,深入探讨如何对torchvision进行微调和特征提取。所有模型都已经预先在1000类的magenet数据集上训练完成。 本节将深入介绍如何使用几个现代的CNN架构,并将直观展示如何微调任意的PyTorch模型。
本节将执行两种类型的迁移学习:

  • 微调:从预训练模型开始,更新我们新任务的所有模型参数,实质上是重新训练整个模型。
  • 特征提取:从预训练模型开始,仅更新从中导出预测的最终图层权重。它被称为特征提取,因为我们使用预训练的CNN作为固定 的特征提取器,并且仅改变输出层。

通常这两种迁移学习方法都会遵循一下步骤:

  • 初始化预训练模型
  • 重组最后一层,使其具有与新数据集类别数相同的输出数
  • 为优化算法定义想要的训练期间更新的参数
  • 运行训练步骤

二、导入相关包

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision 
from torchvision import datasets,models,transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("Pytorch version:",torch.__version__)
print("torchvision version:",torchvision.__version__)

运行结果

在这里插入图片描述

三、数据输入

数据集——>我在这里

链接:https://pan.baidu.com/s/1G3yRfKTQf9sIq1iCSoymWQ
提取码:1234

#%%输入
data_dir="D:\Python\Pytorch\data\hymenoptera_data"
# 从[resnet,alexnet,vgg,squeezenet,desenet,inception]
model_name='squeezenet'
# 数据集中类别数量
num_classes=2
# 训练的批量大小
batch_size=8
# 训练epoch数
num_epochs=15
# 用于特征提取的标志。为FALSE,微调整个模型,为TRUE只更新图层参数
feature_extract=True

四、辅助函数

1、模型训练和验证

  • train_model函数处理给定模型的训练和验证。作为输入,它需要PyTorch模型、数据加载器字典、损失函数、优化器、用于训练和验 证epoch数,以及当模型是初始模型时的布尔标志。
  • is_inception标志用于容纳 Inception v3 模型,因为该体系结构使用辅助输出, 并且整体模型损失涉及辅助输出和最终输出,如此处所述。 这个函数训练指定数量的epoch,并且在每个epoch之后运行完整的验证步骤。它还跟踪最佳性能的模型(从验证准确率方面),并在训练 结束时返回性能最好的模型。在每个epoch之后,打印训练和验证正确率。
#%%模型训练和验证
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False):
    since=time.time()
    val_acc_history=[]
    best_model_wts=copy.deepcopy(model.state_dict())
    best_acc=0.0
    for epoch in range(num_epochs):
        print('Epoch{}/{}'.format(epoch, num_epochs-1))
        print('-'*10)
        # 每个epoch都有一个训练和验证阶段
        for phase in['train','val']:
            if phase=='train':
                model.train()
            else:
                model.eval()
                
            running_loss=0.0
            running_corrects=0
            # 迭代数据
            for inputs,labels in dataloaders[phase]:
                inputs=inputs.to(device)
                labels=labels.to(device)
                # 梯度置零
                optimizer.zero_grad()
                # 向前传播
                with torch.set_grad_enabled(phase=='train'):
                    # 获取模型输出并计算损失,开始的特殊情况在训练中他有一个辅助输出
                    # 在训练模式下,通过将最终输出和辅助输出相加来计算损耗,在测试中值考虑最终输出
                    if is_inception and phase=='train':
                        outputs,aux_outputs=model(inputs)
                        loss1=criterion(outputs,labels)
                        loss2=criterion(aux_outputs,labels)
                        loss=loss1+0.4*loss2
                    else:
                        outputs=model(inputs)
                        loss=criterion(outputs,labels)
                        
                    _,preds=torch.max(outputs,1)
                    
                    if phase=='train':
                        loss.backward()
                        optimizer.step()
                        
                # 添加
                running_loss+=loss.item()*inputs.size(0)
                running_corrects+=torch.sum(preds==labels.data)
                
            epoch_loss=running_loss/len(dataloaders[phase].dataset)
            epoch_acc=running_corrects.double()/len(dataloaders[phase].dataset)
            
            print('{}loss : {:.4f} acc:{:.4f}'.format(phase, epoch_loss,epoch_acc))
            
            if phase=='train' and epoch_acc>best_acc:
                best_acc=epoch_acc
                best_model_wts=copy.deepcopy(model.state_dict())
            if phase=='val':
                val_acc_history.append(epoch_acc)
            
        print()

    time_elapsed=time.time()-since
    print('training complete in {:.0f}s'.format(time_elapsed//60, time_elapsed%60))
    print('best val acc:{:.4f}'.format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model,val_acc_history

2、设置模型参数的'.requires_grad属性'

当我们进行特征提取时,此辅助函数将模型中参数的 .requires_grad 属性设置为False。
默认情况下,当我们加载一个预训练模型时,所有参数都是 .requires_grad = True,如果我们从头开始训练或微调,这种设置就没问题。
但是,如果我们要运行特征提取并且只想为新初始化的层计算梯度,那么我们希望所有其他参数不需要梯度变化。

#%%设置模型参数的.require——grad属性
def set_parameter_requires_grad(model,feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.require_grad=False

靓仔今天先去跑步了,再不跑来不及了,先更这么多,后续明天继续~(感谢有人没有催更!感谢监督!希望继续监督!)

以上就是深入理解Pytorch微调torchvision模型的详细内容,更多关于Pytorch torchvision模型的资料请关注脚本之家其它相关文章!

相关文章

  • Python 列表(List) 的三种遍历方法实例 详解

    Python 列表(List) 的三种遍历方法实例 详解

    这篇文章主要介绍了Python 列表(List) 的三种遍历方法实例 详解的相关资料,需要的朋友可以参考下
    2017-04-04
  • pandas库中to_datetime()方法的使用解析

    pandas库中to_datetime()方法的使用解析

    这篇文章主要介绍了pandas库中to_datetime()方法的使用解析,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-07-07
  • 使用Python操作文件系统的方法

    使用Python操作文件系统的方法

    Python提供了许多内置库来处理文件系统,如os、shutil和pathlib等,这些库可以帮助你创建、删除、读取、写入文件和目录,这篇文章主要介绍了使用Python操作文件系统,需要的朋友可以参考下
    2023-07-07
  • 关于Gradio中Button用法及事件监听器click方法使用

    关于Gradio中Button用法及事件监听器click方法使用

    介绍了在Gradio中使用Button组件和事件监听器的click方法,通过一个简单的示例展示了如何实现点击按钮输出一行文字的功能,在实际项目中遇到了一个错误,经过排查和请教室友后,发现问题出在inputs参数的传递上,需要传入一个包含输入组件的列表
    2024-11-11
  • Python calendar模块详情

    Python calendar模块详情

    这篇文章主要介绍了 Python calendar模块,Python 专门为了处理日历提供了calendar日历模块,下面文章基于time模块和datetime模块展开,具有一定的参考价值,需要的朋友可以参考一下
    2021-11-11
  • Python爬虫lxml库处理XML和HTML文档

    Python爬虫lxml库处理XML和HTML文档

    在当今信息爆炸的时代,网络上的数据量庞大而繁杂,为了高效地从网页中提取信息,Python爬虫工程师们需要强大而灵活的工具,其中,lxml库凭借其卓越的性能和丰富的功能成为Python爬虫领域的不可或缺的工具之一,本文将深入介绍lxml库的各个方面,充分掌握这个强大的爬虫利器
    2023-12-12
  • 一文讲解python中的继承冲突及继承顺序

    一文讲解python中的继承冲突及继承顺序

    python支持多继承,如果子类没有重写方法,则默认会调用父类的方法,本文主要介绍了一文讲解python中的继承冲突及继承顺序,具有一定的参考价值,感兴趣的可以了解一下
    2024-03-03
  • GoReplay中间件python版本使用教程

    GoReplay中间件python版本使用教程

    GoReplay 是一个用于网络流量录制和回放的工具,它可以用于测试和优化分布式系统,这篇文章主要介绍了GoReplay中间件python版本使用教程,需要的朋友可以参考下
    2024-02-02
  • python DataFrame的shift()方法的使用

    python DataFrame的shift()方法的使用

    在python数据分析中,可以使用shift()方法对DataFrame对象的数据进行位置的前滞、后滞移动,本文主要介绍了python DataFrame的shift()方法的使用,感兴趣的可以了解一下
    2022-03-03
  • 使用python对pdf文件进行加密等操作

    使用python对pdf文件进行加密等操作

    这篇文章主要为大家详细介绍了使用python对pdf文件进行加密等操作的相关知识,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下
    2024-12-12

最新评论