python神经网络使用Keras进行模型的保存与读取

 更新时间:2022年05月05日 08:27:50   作者:Bubbliiiing  
这篇文章主要为大家介绍了python神经网络使用Keras进行模型的保存与读取,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

学习前言

开始做项目的话,有些时候会用到别人训练好的模型,这个时候要学会load噢。

Keras中保存与读取的重要函数

1、model.save

model.save用于保存模型,在保存模型前,首先要利用pip install安装h5py的模块,这个模块在Keras的模型保存与读取中常常被使用,用于定义保存格式。

pip install h5py

完成安装后,可以通过如下函数保存模型。

model.save("./model.hdf5")

其中,model是已经训练完成的模型,save函数传入的参数就是保存后的位置+名字。

2、load_model

load_model用于载入模型。

具体使用方式如下:

model = load_model("./model.hdf5")

其中,load_model函数传入的参数就是已经完成保存的模型的位置+名字。./表示保存在当前目录。

全部代码

这是一个简单的手写体识别例子,在之前也讲解过如何构建

python神经网络学习使用Keras进行简单分类,在最后我添加上了模型的保存与读取函数。

import numpy as np
from keras.models import Sequential,load_model,save_model
from keras.layers import Dense,Activation ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import RMSprop
# 获取训练集
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
# 首先进行标准化 
X_train = X_train.reshape(X_train.shape[0],-1)/255
X_test = X_test.reshape(X_test.shape[0],-1)/255
# 计算categorical_crossentropy需要对分类结果进行categorical
# 即需要将标签转化为形如(nb_samples, nb_classes)的二值序列
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
# 构建模型
model = Sequential([
    Dense(32,input_dim = 784),
    Activation("relu"),
    Dense(10),
    Activation("softmax")
    ]
)
rmsprop = RMSprop(lr = 0.001,rho = 0.9,epsilon = 1e-08,decay = 0)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = rmsprop,metrics=['accuracy'])
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 100)
print("\nTest")
# 测试
cost,accuracy = model.evaluate(X_test,Y_test)
print("accuracy:",accuracy)
# 保存模型
model.save("./model.hdf5")
# 删除现有模型
del model
print("model had been del")
# 再次载入模型
model = load_model("./model.hdf5")
# 预测
cost,accuracy = model.evaluate(X_test,Y_test)
print("accuracy:",accuracy)

实验结果为:

Epoch 1/2
60000/60000 [==============================] - 6s 104us/step - loss: 0.4217 - acc: 0.8888
Epoch 2/2
60000/60000 [==============================] - 6s 99us/step - loss: 0.2240 - acc: 0.9366
Test
10000/10000 [==============================] - 1s 149us/step
accuracy: 0.9419
model had been del
10000/10000 [==============================] - 1s 117us/step
accuracy: 0.9419

以上就是python神经网络使用Keras进行模型的保存与读取的详细内容,更多关于Keras模型保存读取的资料请关注脚本之家其它相关文章!

相关文章

  • Python实现爬取网页中动态加载的数据

    Python实现爬取网页中动态加载的数据

    这篇文章主要介绍了Python实现爬取网页中动态加载的数据,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-08-08
  • Python中BeautifulSoup通过查找Id获取元素信息

    Python中BeautifulSoup通过查找Id获取元素信息

    这篇文章主要介绍了Python中BeautifulSoup通过查找Id获取元素信息,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-12-12
  • 使用Python实现一个简单的文件搜索引擎

    使用Python实现一个简单的文件搜索引擎

    这篇文章主要为大家详细介绍了Python中文件操作的基础和进阶知识并基于以上知识实现了一个简单的文件搜索引擎,感兴趣的小伙伴可以参考一下
    2024-05-05
  • PyTorch里面的torch.nn.Parameter()详解

    PyTorch里面的torch.nn.Parameter()详解

    今天小编就为大家分享一篇PyTorch里面的torch.nn.Parameter()详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • 使用Python高效获取网络数据的操作指南

    使用Python高效获取网络数据的操作指南

    网络爬虫是一种自动化程序,用于访问和提取网站上的数据,Python是进行网络爬虫开发的理想语言,拥有丰富的库和工具,使得编写和维护爬虫变得简单高效,本文将详细介绍如何使用Python进行网络爬虫开发,包括基本概念、常用库、数据提取方法、反爬措施应对以及实际案例
    2025-03-03
  • python中matplotlib实现随鼠标滑动自动标注代码

    python中matplotlib实现随鼠标滑动自动标注代码

    这篇文章主要介绍了python中matplotlib实现随鼠标滑动自动标注代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04
  • Python 创建TCP服务器的方法

    Python 创建TCP服务器的方法

    这篇文章主要介绍了Python 创建TCP服务器的方法,文中讲解非常细致,代码帮助大家更好的理解和学习,感兴趣的朋友可以了解下
    2020-07-07
  • Django haystack实现全文搜索代码示例

    Django haystack实现全文搜索代码示例

    这篇文章主要介绍了Django haystack实现全文搜索代码示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-11-11
  • Python any()函数的使用方法

    Python any()函数的使用方法

    这篇文章主要介绍了Python any()函数的使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-10-10
  • Tensorflow2.1实现Fashion图像分类示例详解

    Tensorflow2.1实现Fashion图像分类示例详解

    这篇文章主要为大家介绍了Tensorflow2.1实现Fashion图像分类示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2022-11-11

最新评论