keras导入weights方式

 更新时间:2020年06月12日 14:47:22   作者:wangyin_2014  
这篇文章主要介绍了keras导入weights方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • python实现的阳历转阴历(农历)算法

    python实现的阳历转阴历(农历)算法

    这篇文章主要介绍了python实现的阳历转阴历(农历)算法,需要的朋友可以参考下
    2014-04-04
  • Python seaborn barplot画图案例

    Python seaborn barplot画图案例

    这篇文章主要介绍了Python seaborn barplot画图案例,文章围绕主题展开详细的内容介绍,具有一定的参考价值,需要的小伙伴可以参考一下
    2022-07-07
  • Python中Threading用法详解

    Python中Threading用法详解

    本篇文章给大家介绍了Python中Threading的详细用法,需要的朋友跟着小编一起学习下吧。
    2017-12-12
  • python能在浏览器能运行吗

    python能在浏览器能运行吗

    在本篇文章里小编给大家整理了关于python能否在浏览器能运行的相关知识点内容,有需要的朋友们可以学习下。
    2020-06-06
  • 使用python实现自动化控制电脑版微信

    使用python实现自动化控制电脑版微信

    这篇文章主要为大家详细介绍了如何通过Python去调用Windows API实现模拟人工操作的方式去实现控制微信电脑版,感兴趣的小伙伴可以跟随小编一起学习一下
    2023-10-10
  • 详解django中Template语言

    详解django中Template语言

    Django是一个开放源代码的Web应用框架,由Python写成。这篇文章给大家介绍django中Template语言,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友参考下吧
    2020-02-02
  • 对Tensorflow中tensorboard日志的生成与显示详解

    对Tensorflow中tensorboard日志的生成与显示详解

    今天小编就为大家分享一篇对Tensorflow中tensorboard日志的生成与显示详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-02-02
  • python异常处理并调试

    python异常处理并调试

    这篇文章主要介绍了python异常处理并调试,异常是错误出现时,可以在正常的控制流程之外采取的行为下面我们就来看看python的那些异常,需要的小伙伴可以参考一下
    2022-02-02
  • 使用OpenCV circle函数图像上画圆的示例代码

    使用OpenCV circle函数图像上画圆的示例代码

    这篇文章主要介绍了使用OpenCV circle函数图像上画圆的示例代码,本文内容简短,给大家突出重点内容,需要的朋友可以参考下
    2019-12-12
  • 跟老齐学Python之永远强大的函数

    跟老齐学Python之永远强大的函数

    Python程序中的语句都会组织成函数的形式。通俗地说,函数就是完成特定功能的一个语句组,这组语句可以作为一个单位使用,并且给它取一个名字,这样,我们就可以通过函数名在程序的不同地方多次执行(这通常叫做函数调用),却不需要在所有地方都重复编写这些语句。
    2014-09-09

最新评论