keras导入weights方式

 更新时间:2020年06月12日 14:47:22   作者:wangyin_2014  
这篇文章主要介绍了keras导入weights方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 如何将自己的python库打包成wheel文件并上传到pypi

    如何将自己的python库打包成wheel文件并上传到pypi

    这篇文章主要介绍了如何将自己的python库打包成wheel文件并上传到pypi,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-04-04
  • Matplotlib使用和绘制二维图表教程

    Matplotlib使用和绘制二维图表教程

    Matplotlib是一个强大的Python绘图库,可以用来绘制各种静态、动态和交互式的图表,文章介绍了Matplotlib的基本概念、绘制折线图、散点图、柱状图、直方图和饼图等方法,并详细解释了Matplotlib的三层结构
    2025-02-02
  • np.random.seed() 的使用详解

    np.random.seed() 的使用详解

    这篇文章主要介绍了np.random.seed() 的使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-01-01
  • python中filter,map,reduce的作用

    python中filter,map,reduce的作用

    这篇文章主要介绍了python中filter,map,reduce的作用,文章首先通过map函数展开,map主要作用是计算一个序列或者多个序列进行函数映射之后的值,感兴趣的朋友可以参考一下
    2022-06-06
  • python实现把二维列表变为一维列表的方法分析

    python实现把二维列表变为一维列表的方法分析

    这篇文章主要介绍了python实现把二维列表变为一维列表的方法,结合实例形式总结分析了Python列表推导式、嵌套、循环等相关操作技巧,需要的朋友可以参考下
    2019-10-10
  • PyTorch 如何检查模型梯度是否可导

    PyTorch 如何检查模型梯度是否可导

    这篇文章主要介绍了PyTorch 检查模型梯度是否可导的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2021-06-06
  • Python+Pygame实战之俄罗斯方块游戏的实现

    Python+Pygame实战之俄罗斯方块游戏的实现

    俄罗斯方块,作为是一款家喻户晓的游戏,陪伴70、80甚至90后,度过无忧的儿时岁月,它上手简单能自由组合、拼接技巧也很多。本文就来用Python中的Pygame模块实现这一经典游戏,需要的可以参考一下
    2022-12-12
  • 深度剖析使用python抓取网页正文的源码

    深度剖析使用python抓取网页正文的源码

    平时打开一个网页,除了文章的正文内容,通常会有一大堆的导航,广告和其他方面的信息。本文的目的,在于说明如何从一个网页中提取出文章的正文内容,而过渡掉其他无关的的信息。
    2014-06-06
  • python 列表删除所有指定元素的方法

    python 列表删除所有指定元素的方法

    下面小编就为大家分享一篇python 列表删除所有指定元素的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • 12个步骤教你理解Python装饰器

    12个步骤教你理解Python装饰器

    这篇文章主要介绍了12个步骤教你理解Python装饰器,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07

最新评论