Keras保存模型并载入模型继续训练的实现

 更新时间:2021年02月20日 09:41:14   作者:凌逆战  
这篇文章主要介绍了Keras保存模型并载入模型继续训练的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

我们以MNIST手写数字识别为例

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
    Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
  ])
 
# 定义优化器
sgd = SGD(lr=0.2)
 
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
  optimizer = sgd,
  loss = 'mse',
  metrics=['accuracy'],
)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存模型
model.save('model.h5')  # HDF5文件,pip install h5py

载入初次训练的模型,再训练

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
 
# 载入模型
model = load_model('model.h5')
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
 
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
 
print('\ntest loss',loss)
print('accuracy',accuracy)
 
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
 
print(json_string)

关于compile和load_model()的使用顺序

这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?

compile做什么?

compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。

如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。

是否需要多次编译?

除非我们要更改其中之一:损失函数、优化器 / 学习率、度量

又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。

再次compile的后果?

如果再次编译模型,将会丢失优化器状态.

这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。

到此这篇关于Keras保存模型并载入模型继续训练的实现的文章就介绍到这了,更多相关Keras保存模型并加载模型内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 基于python脚本实现软件的注册功能(机器码+注册码机制)

    基于python脚本实现软件的注册功能(机器码+注册码机制)

    用户运行程序后,通过文件自动检测认证状态,如果未经认证,就需要注册。这篇文章主要介绍了基于python脚本实现软件的注册功能(机器码+注册码机制)的相关资料,需要的朋友可以参考下
    2016-10-10
  • Python中PyQt5/PySide2的按钮控件使用实例

    Python中PyQt5/PySide2的按钮控件使用实例

    这篇文章主要介绍了PyQt5/PySide2的按钮控件使用实例,代码简单易懂,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-08-08
  • 解决keras使用cov1D函数的输入问题

    解决keras使用cov1D函数的输入问题

    这篇文章主要介绍了解决keras使用cov1D函数的输入问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • Tensorflow实现神经网络拟合线性回归

    Tensorflow实现神经网络拟合线性回归

    这篇文章主要为大家详细介绍了Tensorflow实现神经网络拟合线性回归,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-07-07
  • Python+tkinter实现高清图片保存

    Python+tkinter实现高清图片保存

    作为爱玩电脑的你是不是也需要经常更换一下自己的电脑壁纸呢?但是在网上有很多心仪的图片想要保存下来,如果一张张的去保存那效率又低。所以本文用Python写一个保存图片的功能,把我们的图片给保存到我们的电脑,需要的可以参考一下
    2022-03-03
  • python中os包的用法

    python中os包的用法

    这篇文章主要介绍了python中os包的用法,文中给大家提到了python中os的常用方法,给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-06-06
  • pandas.DataFrame Series排序的使用(sort_values,sort_index)

    pandas.DataFrame Series排序的使用(sort_values,sort_index)

    本文主要介绍了pandas.DataFrame Series排序的使用(sort_values,sort_index),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • Python编程中NotImplementedError的使用方法

    Python编程中NotImplementedError的使用方法

    下面小编就为大家分享一篇Python编程中NotImplementedError的使用方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • python虚拟环境的安装和配置(virtualenv,virtualenvwrapper)

    python虚拟环境的安装和配置(virtualenv,virtualenvwrapper)

    这篇文章主要介绍了python虚拟环境的安装和配置(virtualenv,virtualenvwrapper),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-08-08
  • 如何将python代码打包成pip包(可以pip install)

    如何将python代码打包成pip包(可以pip install)

    这篇文章主要介绍了如何将python代码打包成pip包(可以pip install),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02

最新评论