使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

 更新时间:2020年01月18日 16:14:58   作者:sjtu_leexx  
今天小编就为大家分享一篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 深入理解python中的ThreadLocal

    深入理解python中的ThreadLocal

    本文主要介绍了深入理解python中的ThreadLocal,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-03-03
  • 如何使用Django默认的Auth权限管理系统

    如何使用Django默认的Auth权限管理系统

    本文主要介绍了如何使用Django默认的Auth权限管理系统,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • Python中的文件和目录操作实现代码

    Python中的文件和目录操作实现代码

    对于文件和目录的处理,虽然可以通过操作系统命令来完成,但是Python语言为了便于开发人员以编程的方式处理相关工作,提供了许多处理文件和目录的内置函数。重要的是,这些函数无论是在Unix、Windows还是Macintosh平台上,它们的使用方式是完全一致的。
    2011-03-03
  • Pycharm保存不能自动同步到远程服务器的解决方法

    Pycharm保存不能自动同步到远程服务器的解决方法

    今天小编就为大家分享一篇Pycharm保存不能自动同步到远程服务器的解决方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-06-06
  • pytorch动态神经网络(拟合)实现

    pytorch动态神经网络(拟合)实现

    这篇文章主要介绍了pytorch动态神经网络(拟合)实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-03-03
  • Python中用startswith()函数判断字符串开头的教程

    Python中用startswith()函数判断字符串开头的教程

    这篇文章主要介绍了Python中用startswith()函数判断字符串开头的教程,startswith()函数的使用是Python学习中的基础知识,本文列举了一些不同情况下的使用结果,需要的朋友可以参考下
    2015-04-04
  • Python使用googletrans报错的解决方法

    Python使用googletrans报错的解决方法

    这篇文章主要给大家介绍了关于Python使用googletrans报错的解决方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2018-09-09
  • 33个Python爬虫项目实战(推荐)

    33个Python爬虫项目实战(推荐)

    这篇文章主要介绍了33个Python爬虫项目实战,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-07-07
  • 对python实现二维函数高次拟合的示例详解

    对python实现二维函数高次拟合的示例详解

    今天小编就为大家分享一篇对python实现二维函数高次拟合的示例详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • pytest多线程与多设备并发appium

    pytest多线程与多设备并发appium

    这篇文章介绍了pytest多线程与多设备并发appium,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-06-06

最新评论