python神经网络Keras构建CNN网络训练

 更新时间:2022年05月04日 12:29:46   作者:Bubbliiiing  
这篇文章主要为大家介绍了python神经网络学习使用Keras构建CNN网络训练,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

利用Keras构建完普通BP神经网络后,还要会构建CNN

Keras中构建CNN的重要函数

1、Conv2D

Conv2D用于在CNN中构建卷积层,在使用它之前需要在库函数处import它。

from keras.layers import Conv2D

在实际使用时,需要用到几个参数。

Conv2D(
    nb_filter = 32,
    nb_row = 5,
    nb_col = 5,
    border_mode = 'same',
    input_shape = (28,28,1)
)

其中,nb_filter代表卷积层的输出有多少个channel,卷积之后图像会越来越厚,这就是卷积后图像的厚度。nb_row和nb_col的组合就是卷积器的大小,这里卷积器是(5,5)的大小。border_mode代表着padding的方式,same表示卷积前后图像的shape不变。input_shape代表输入的shape。

2、MaxPooling2D

MaxPooling2D指的是池化层,在使用它之前需要在库函数处import它。

from keras.layers import MaxPooling2D

在实际使用时,需要用到几个参数。

MaxPooling2D(
    pool_size = (2,2),
    strides = (2,2),
    border_mode = 'same'
)

其中,pool_size表示池化器的大小,在这里,池化器的shape是(2,2)。strides是池化器的步长,这里在X和Y方向上都是2,池化后,输出比输入的shape小了1/2。border_mode代表着padding的方式。

3、Flatten

Flatten用于将卷积池化后最后的输出变为一维向量,这样才可以和全连接层连接,用于计算。在使用前需要用import导入。

from keras.layers import Flatten

在实际使用时,在最后一个池化层后直接添加层即可

model.add(Flatten())

全部代码

这是一个卷积神经网络的例子,用于识别手写体,其神经网络结构如下:

卷积层1->池化层1->卷积层2->池化层2->flatten->全连接层1->全连接层2->全连接层3。

单个样本的shape如下:

(28,28,1)->(28,28,32)->(14,14,32)->(14,14,64)->(7,7,64)->(3136)->(1024)->(256)

import numpy as np
from keras.models import Sequential
from keras.layers import Dense,Activation,Conv2D,MaxPooling2D,Flatten ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28,1)
X_test = X_test.reshape(-1,28,28,1)
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
model = Sequential()
# conv1
model.add(
    Conv2D(
        nb_filter = 32,
        nb_row = 5,
        nb_col = 5,
        border_mode = 'same',
        input_shape = (28,28,1)
    )
)
model.add(Activation("relu"))
# pool1
model.add(
    MaxPooling2D(
        pool_size = (2,2),
        strides = (2,2),
        border_mode = 'same'
    )
)
# conv2
model.add(
    Conv2D(
        nb_filter = 64,
        nb_row = 5,
        nb_col = 5,
        border_mode = 'same'
    )
)
model.add(Activation("relu"))
# pool2
model.add(
    MaxPooling2D(
        pool_size = (2,2),
        strides = (2,2),
        border_mode = 'same'
    )
)
# 全连接层
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation("relu"))
model.add(Dense(256))
model.add(Activation("relu"))
model.add(Dense(10))
model.add(Activation("softmax"))
adam = Adam(lr = 1e-4)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
## tarin
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 32)
print("\nTest")
## acc
cost,accuracy = model.evaluate(X_test,Y_test)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)

实验结果为:

Epoch 1/2
60000/60000 [==============================] - 64s 1ms/step - loss: 0.7664 - acc: 0.9224
Epoch 2/2
60000/60000 [==============================] - 62s 1ms/step - loss: 0.0473 - acc: 0.9858
Test
10000/10000 [==============================] - 2s 169us/step
accuracy: 0.9856

以上就是python神经网络Keras构建CNN网络训练的详细内容,更多关于Keras构建CNN网络训练的资料请关注脚本之家其它相关文章!

相关文章

  • python出现RuntimeError错误问题及解决

    python出现RuntimeError错误问题及解决

    这篇文章主要介绍了python出现RuntimeError错误问题及解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-05-05
  • pycharm中dgl安装报错FileNotFoundError:Could not find module ‘E:\XXXX\XXXX\lib\site-packages\dgl\dgl.dl

    pycharm中dgl安装报错FileNotFoundError:Could not find&nb

    这篇文章主要介绍了pycharm中dgl安装报错FileNotFoundError:Could not find module ‘E:\XXXX\XXXX\lib\site-packages\dgl\dgl.dl问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2024-02-02
  • Python进阶之迭代器与迭代器切片教程

    Python进阶之迭代器与迭代器切片教程

    迭代器是 Python 中独特的一种高级特性,而切片也是一种高级特性,两者相结合,会产生什么样的结果呢,需要的朋友可以参考下
    2020-01-01
  • 在Python下进行UDP网络编程的教程

    在Python下进行UDP网络编程的教程

    这篇文章主要介绍了在Python下进行UDP网络编程的教程,UDP编程是Python网络编程部分的基础知识,示例代码基于Python2.x版本,需要的朋友可以参考下
    2015-04-04
  • Python多线程中线程数量如何控制

    Python多线程中线程数量如何控制

    本文主要介绍了Python多线程中线程数量如何控制,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-01-01
  • Python 剪绳子的多种思路实现(动态规划和贪心)

    Python 剪绳子的多种思路实现(动态规划和贪心)

    这篇文章主要介绍了Python 剪绳子的多种思路实现(动态规划和贪心),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-02-02
  • Python深入06——python的内存管理详解

    Python深入06——python的内存管理详解

    本篇文章主要介绍了python的内存管理详解,语言的内存管理是语言设计的一个重要方面。它是决定语言性能的重要因素。有兴趣的同学可以了解一下。
    2016-12-12
  • 一文搞懂关于 sys.argv 的详解

    一文搞懂关于 sys.argv 的详解

    sys.argv 其实就是一个列表,里边需要用户传入的参数,关键就是要明白这参数是从程序外部输入的,而非代码本身的什么地方,要想看到它的效果就应该将程序保存了,从外部来运行程序并给出参数,通过本文学习你将明白 sys.argv很多知识,感兴趣的朋友一起看看吧
    2023-01-01
  • python 监控服务器是否有人远程登录(详细思路+代码)

    python 监控服务器是否有人远程登录(详细思路+代码)

    这篇文章主要介绍了python 监控服务器是否有人远程登录的方法,帮助大家利用python 监控服务器,感兴趣的朋友可以了解下
    2020-12-12
  • python的reverse函数翻转结果为None的问题

    python的reverse函数翻转结果为None的问题

    这篇文章主要介绍了python的reverse函数翻转结果为None的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05

最新评论