Tensorflow之MNIST CNN实现并保存、加载模型

 更新时间:2020年06月17日 10:25:55   作者:uflswe  
这篇文章主要为大家详细介绍了Tensorflow之MNIST CNN实现并保存、加载模型,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • 如何修改新版Python的pip默认安装路径

    如何修改新版Python的pip默认安装路径

    pip安装的第三方库默认存放在C盘中,为了便于管理和不过度占用C盘空间所以想修改默认的pip路径,这篇文章主要介绍了修改新版Python的pip默认安装路径的过程,需要的朋友可以参考下
    2024-03-03
  • Python Pandas中loc和iloc函数的基本用法示例

    Python Pandas中loc和iloc函数的基本用法示例

    无论是loc还是iloc都是pandas中数据筛选的函数,下面这篇文章主要给大家介绍了关于Python Pandas中loc和iloc函数的基本用法示例,文中通过示例代码介绍的非常详细,需要的朋友可以参考下
    2022-07-07
  • python实现在线翻译

    python实现在线翻译

    这篇文章主要介绍了python实现在线翻译,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-06-06
  • Python 获得命令行参数的方法(推荐)

    Python 获得命令行参数的方法(推荐)

    本篇将介绍python中sys, getopt模块处理命令行参数的方法,本文给大家介绍的非常详细,具有参考借鉴价值,需要的朋友参考下吧
    2018-01-01
  • 详解Python3序列赋值、序列解包

    详解Python3序列赋值、序列解包

    这篇文章主要介绍了Python3序列赋值、序列解包的相关知识,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-05-05
  • 深入解读Python如何进行文件读写

    深入解读Python如何进行文件读写

    文件的作用 就是把一些存储存放起来,可以让程序下一次执行的时候直接使用,而不必重新制作一份,省时省力,本文将带你了解通过python如何进行文件的读写操作
    2021-10-10
  • Pytorch-mlu 实现添加逐层算子方法详解

    Pytorch-mlu 实现添加逐层算子方法详解

    本文主要分享了在寒武纪设备上 pytorch-mlu 中添加逐层算子的方法教程,代码具有一定学习价值,有需要的朋友可以借鉴参考下,希望能够有所帮助
    2021-11-11
  • 使用pycallgraph分析python代码函数调用流程以及框架解析

    使用pycallgraph分析python代码函数调用流程以及框架解析

    这篇文章主要介绍了使用pycallgraph分析python代码函数调用流程以及框架解析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-03-03
  • python:接口间数据传递与调用方法

    python:接口间数据传递与调用方法

    今天小编就为大家分享一篇python:接口间数据传递与调用方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • python使用筛选法计算小于给定数字的所有素数

    python使用筛选法计算小于给定数字的所有素数

    这篇文章主要为大家详细介绍了python使用筛选法计算小于给定数字的所有素数,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2018-03-03

最新评论