Keras使用ImageNet上预训练的模型方式

 更新时间:2020年05月23日 15:09:06   作者:breeze5428  
这篇文章主要介绍了Keras使用ImageNet上预训练的模型方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

我就废话不多说了,大家还是直接看代码吧!

import keras
import numpy as np
from keras.applications import vgg16, inception_v3, resnet50, mobilenet
 
#Load the VGG model
vgg_model = vgg16.VGG16(weights='imagenet')
 
#Load the Inception_V3 model
inception_model = inception_v3.InceptionV3(weights='imagenet')
 
#Load the ResNet50 model
resnet_model = resnet50.ResNet50(weights='imagenet')
 
#Load the MobileNet model
mobilenet_model = mobilenet.MobileNet(weights='imagenet')

在以上代码中,我们首先import各种模型对应的module,然后load模型,并用ImageNet的参数初始化模型的参数。

如果不想使用ImageNet上预训练到的权重初始话模型,可以将各语句的中'imagenet'替换为'None'。

补充知识:keras上使用alexnet模型来高准确度对mnist数据进行分类

纲要

本文有两个特点:一是直接对本地mnist数据进行读取(假设事先已经下载或从别处拷来)二是基于keras框架(网上多是基于tf)使用alexnet对mnist数据进行分类,并获得较高准确度(约为98%)

本地数据读取和分析

很多代码都是一开始简单调用一行代码来从网站上下载mnist数据,虽然只有10来MB,但是现在下载速度非常慢,而且经常中途出错,要费很大的劲才能拿到数据。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

其实可以单独来获得这些数据(一共4个gz包,如下所示),然后调用别的接口来分析它们。

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

这里面要注意的是,两种接口拿到的数据形式是不一样的。 从网上直接下载下来的数据 其image data值的范围是0~255,且label值为0,1,2,3...9。 而第二种接口获取的数据 image值已经除以255(归一化)变成0~1范围,且label值已经是one-hot形式(one_hot=True时),比如label值2的one-hot code为(0 0 1 0 0 0 0 0 0 0)

所以,以第一种方式获取的数据需要做一些预处理(归一和one-hot)才能输入网络模型进行训练 而第二种接口拿到的数据则可以直接进行训练。

Alexnet模型的微调

按照公开的模型框架,Alexnet只有第1、2个卷积层才跟着BatchNormalization,后面三个CNN都没有(如有说错,请指正)。如果按照这个来搭建网络模型,很容易导致梯度消失,现象就是 accuracy值一直处在很低的值。 如下所示。

在每个卷积层后面都加上BN后,准确度才迭代提高。如下所示

完整代码

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #tensorflow已经包含了mnist案例的数据
 
batch_size = 64
num_classes = 10
epochs = 10
img_shape = (28,28,1)
 
# input dimensions
img_rows, img_cols = 28,28
 
# dataset input
#(x_train, y_train), (x_test, y_test) = mnist.load_data()
mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
 
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
 
# data initialization
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
 
# Define the input layer
inputs = keras.Input(shape = [img_rows, img_cols, 1])
 
 #Define the converlutional layer 1
conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)
# Define the pooling layer 1
pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)
# Define the standardization layer 1
stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)
 
# Define the converlutional layer 2
conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)
# Defien the pooling layer 2
pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)
# Define the standardization layer 2
stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)
 
# Define the converlutional layer 3
conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)
stand3 = keras.layers.BatchNormalization(axis=1)(conv3)
 
# Define the converlutional layer 4
conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)
stand4 = keras.layers.BatchNormalization(axis=1)(conv4)
 
# Define the converlutional layer 5
conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)
pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)
stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)
 
# Define the fully connected layer
flatten = keras.layers.Flatten()(stand5)
fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)
drop1 = keras.layers.Dropout(0.5)(fc1)
 
fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)
drop2 = keras.layers.Dropout(0.5)(fc2)
 
fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)
 
# 基于Model方法构建模型
model = keras.Model(inputs= inputs, outputs = fc3)
# 编译模型
model.compile(optimizer= tf.train.AdamOptimizer(0.001),
       loss= keras.losses.categorical_crossentropy,
       metrics= ['accuracy'])
# 训练配置,仅供参考
model.fit(x_train, y_train, batch_size= batch_size, epochs= epochs, validation_data=(x_test,y_test))

以上这篇Keras使用ImageNet上预训练的模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python异步爬虫实现原理与知识总结

    Python异步爬虫实现原理与知识总结

    之前有很多小伙伴想看Python异步爬虫的有关知识总结,这次它来了,文中有非常详细的代码示例与注释,即使对刚开始学python的小伙伴也很友好,,需要的朋友可以参考下
    2021-05-05
  • python中的Pytorch建模流程汇总

    python中的Pytorch建模流程汇总

    这篇文章主要介绍了python中的Pytorch建模流程汇总,主要帮大家帮助大家梳理神经网络训练的架构,具有一的的参考价值,需要的小伙伴可以参考一下,希望对你的学习有所帮助
    2022-03-03
  • python模拟登陆Tom邮箱示例分享

    python模拟登陆Tom邮箱示例分享

    这篇文章主要介绍了python登陆Tom邮箱的示例,大家参考使用吧
    2014-01-01
  • django 获取字段最大值,最新的记录操作

    django 获取字段最大值,最新的记录操作

    这篇文章主要介绍了django 获取字段最大值,最新的记录操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-08-08
  • python程序调用远程服务的步骤详解

    python程序调用远程服务的步骤详解

    这篇文章主要介绍了python程序调用远程服务的步骤详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧
    2021-03-03
  • pytorch 归一化与反归一化实例

    pytorch 归一化与反归一化实例

    今天小编就为大家分享一篇pytorch 归一化与反归一化实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • kaggle数据分析家庭电力消耗过程详解

    kaggle数据分析家庭电力消耗过程详解

    这篇文章主要为大家介绍了kaggle数据分析家庭电力消耗示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-12-12
  • Python利用Beautiful Soup模块创建对象详解

    Python利用Beautiful Soup模块创建对象详解

    这篇文章主要介绍了Python利用Beautiful Soup模块创建对象的相关资料,文中介绍的非常详细,相信对大家具有一定的参考价值,需要的朋友们下面来一起看看吧。
    2017-03-03
  • Python中正反斜杠(‘/’和‘\’)的意义与用法

    Python中正反斜杠(‘/’和‘\’)的意义与用法

    这篇文章主要给大家介绍了关于Python中正反斜杠(‘/’和‘\’)的意义与使用方法,文中通过示例代码介绍的非常详细,对大家学习或者使用Python具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-08-08
  • 使用apidoc管理RESTful风格Flask项目接口文档方法

    使用apidoc管理RESTful风格Flask项目接口文档方法

    下面小编就为大家分享一篇使用apidoc管理RESTful风格Flask项目接口文档方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-02-02

最新评论