python神经网络ResNet50模型的复现详解

 更新时间:2022年05月06日 16:03:47   作者:Bubbliiiing  
这篇文章主要为大家介绍了python神经网络ResNet50模型的复现详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

什么是残差网络

最近看yolo3里面讲到了残差网络,对这个网络结构很感兴趣,于是了解到这个网络结构最初的使用是在ResNet网络里。

Residual net(残差网络):

将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分。

意味着后面的特征层的内容会有一部分由其前面的某一层线性贡献。

其结构如下:

深度残差网络的设计是为了克服由于网络深度加深而产生的学习效率变低与准确率无法有效提升的问题。

什么是ResNet50模型

ResNet50有两个基本的块,分别名为Conv Block和Identity Block,其中Conv Block输入和输出的维度是不一样的,所以不能连续串联,它的作用是改变网络的维度;

Identity Block输入维度和输出维度相同,可以串联,用于加深网络的。

Conv Block的结构如下:

Identity Block的结构如下:

这两个都是残差网络结构。

总的网络结构如下:

这样看起来可能比较抽象,还有一副很好的我从网上找的图,可以拉到最后面去看哈,放前面太占位置了。

ResNet50网络部分实现代码

#-------------------------------------------------------------#
#   ResNet50的网络部分
#-------------------------------------------------------------#
from __future__ import print_function
import numpy as np
from keras import layers
from keras.layers import Input
from keras.layers import Dense,Conv2D,MaxPooling2D,ZeroPadding2D,AveragePooling2D
from keras.layers import Activation,BatchNormalization,Flatten
from keras.models import Model
from keras.preprocessing import image
import keras.backend as K
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import decode_predictions
from keras.applications.imagenet_utils import preprocess_input
def identity_block(input_tensor, kernel_size, filters, stage, block):
    filters1, filters2, filters3 = filters
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    x = Conv2D(filters1, (1, 1), name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters2, kernel_size,padding='same', name=conv_name_base + '2b')(x)
    x = BatchNormalization(name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
    x = BatchNormalization(name=bn_name_base + '2c')(x)
    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    filters1, filters2, filters3 = filters
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    x = Conv2D(filters1, (1, 1), strides=strides,
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters2, kernel_size, padding='same',
               name=conv_name_base + '2b')(x)
    x = BatchNormalization(name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
    x = BatchNormalization(name=bn_name_base + '2c')(x)
    shortcut = Conv2D(filters3, (1, 1), strides=strides,
                      name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(name=bn_name_base + '1')(shortcut)
    x = layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x
def ResNet50(input_shape=[224,224,3],classes=1000):
    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3, 3))(img_input)
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)
    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
    x = AveragePooling2D((7, 7), name='avg_pool')(x)
    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)
    model = Model(img_input, x, name='resnet50')
    model.load_weights("resnet50_weights_tf_dim_ordering_tf_kernels.h5")
    return model

图片预测

建立网络后,可以用以下的代码进行预测。

if __name__ == '__main__':
    model = ResNet50()
    model.summary()
    img_path = 'elephant.jpg'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    print('Input image shape:', x.shape)
    preds = model.predict(x)
    print('Predicted:', decode_predictions(preds))

预测所需的已经训练好的ResNet50模型可以在https://github.com/fchollet/deep-learning-models/releases下载。非常方便。

预测结果为:

Predicted: [[('n01871265', 'tusker', 0.41107917), ('n02504458', 'African_elephant', 0.39015812), ('n02504013', 'Indian_elephant', 0.12260196), ('n03000247', 'chain_mail', 0.023176488), ('n02437312', 'Arabian_camel', 0.020982226)]]

ResNet50模型的完整的结构图

以上就是python神经网络ResNet50模型的复现详解的详细内容,更多关于ResNet50模型复现的资料请关注脚本之家其它相关文章!

相关文章

  • Python+OpenCV进行不规则多边形ROI区域提取

    Python+OpenCV进行不规则多边形ROI区域提取

    ROI即感兴趣区域。机器视觉、图像处理中,从被处理的图像以方框、圆、椭圆、不规则多边形等方式勾勒出需要处理的区域,称为感兴趣区域,ROI。本文将利用Python和OpenCV实现不规则多边形ROI区域提取,需要的可以参考一下
    2022-03-03
  • python中字符串String及其常见操作指南(方法、函数)

    python中字符串String及其常见操作指南(方法、函数)

    String方法是用来处理代码中的字符串的,它几乎能搞定你所遇到的所有字符串格式,下面这篇文章主要给大家介绍了关于python中字符串String及其常见操作(方法、函数)的相关资料,需要的朋友可以参考下
    2022-04-04
  • pycharm终端解释器与Python解释器配置

    pycharm终端解释器与Python解释器配置

    这篇文章主要介绍了pycharm终端解释器与Python解释器配置,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-06-06
  • python实现斐波那契数列的方法示例

    python实现斐波那契数列的方法示例

    每个码农大概都会用自己擅长的语言写出一个斐波那契数列出来,斐波那契数列简单地说,起始两项为0和1,此后的项分别为它的前两项之后。下面这篇文章就给大家详细介绍了python实现斐波那契数列的方法,有需要的朋友们可以参考借鉴,下面来一起看看吧。
    2017-01-01
  • Python操作Excel文件的11种方法(全网最全)

    Python操作Excel文件的11种方法(全网最全)

    在日常工作或开发过程中,Excel文件作为一种常用的数据存储格式,其高效便捷的数据处理能力被广泛应用于数据统计、数据分析等领域,Python作为一种强大的编程语言,提供了丰富的库支持来实现对Excel文件的操作,本篇将详细介绍如何使用Python来操作Excel文件
    2025-03-03
  • Python+Turtle制作独特的表白图

    Python+Turtle制作独特的表白图

    这篇文章主要利用Python和Turtle库绘制独特的表白图,文中的示例代码讲解详细,对我们学习Python有一定的帮助,感兴趣的可以了解一下
    2022-04-04
  • Python结合requests和Cheerio处理网页内容的操作步骤

    Python结合requests和Cheerio处理网页内容的操作步骤

    Python因其简洁明了的语法和强大的库支持,成为了编写爬虫程序的首选语言之一,requests库是Python中用于发送HTTP请求的第三方库,而Cheerio库则是一个用于解析HTML和XML文档的库,本文给大家介绍了Python结合requests和Cheerio处理网页内容的操作步骤
    2025-01-01
  • 利用PyQt5+Matplotlib 绘制静态/动态图的实现代码

    利用PyQt5+Matplotlib 绘制静态/动态图的实现代码

    这篇文章主要介绍了利用PyQt5+Matplotlib 绘制静态/动态图的实现代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-07-07
  • python3报错check_hostname requires server_hostname的解决

    python3报错check_hostname requires server_hostname的解决

    这篇文章主要介绍了python3报错check_hostname requires server_hostname的解决,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-12-12
  • Python-openpyxl表格读取写入的案例详解

    Python-openpyxl表格读取写入的案例详解

    这篇文章主要介绍了Python-openpyxl表格读取写入的案例分析,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-11-11

最新评论