tensorflow saver 保存和恢复指定 tensor的实例讲解

 更新时间:2018年07月26日 09:29:50   作者:血影雪梦  
今天小编就为大家分享一篇tensorflow saver 保存和恢复指定 tensor的实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()     

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python GUI程序类写法与Label介绍

    Python GUI程序类写法与Label介绍

    这篇文章主要介绍了Python GUI程序类写法与Label介绍,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧
    2022-09-09
  • 详解Python中RegEx在数据处理中的应用

    详解Python中RegEx在数据处理中的应用

    正则表达式(Regular Expressions,简称 RegEx)是一种强大的文本匹配和搜索工具,它在数据处理、文本解析和字符串操作中发挥着关键作用,下面就跟随小编一起来了解一下RegEx的具体使用吧
    2024-01-01
  • python使用Matplotlib绘制多种常见图形

    python使用Matplotlib绘制多种常见图形

    这篇文章主要介绍了python使用Matplotlib绘制多种常见图形,文章围绕主题展开详细的用Matplotlib绘制内容,需要的小伙伴可以参考一下
    2022-05-05
  • python接口调用已训练好的caffe模型测试分类方法

    python接口调用已训练好的caffe模型测试分类方法

    今天小编就为大家分享一篇python接口调用已训练好的caffe模型测试分类方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • Python控制浏览器自动下载歌词评论并生成词云图

    Python控制浏览器自动下载歌词评论并生成词云图

    本文主要介绍了如何利用Python控制浏览器自动把歌词评论下载下来,并做成好看的词云图。文中的示例代码讲解详细,感兴趣的可以试一试
    2022-01-01
  • Django模板标签{% for %}循环,获取制定条数据实例

    Django模板标签{% for %}循环,获取制定条数据实例

    这篇文章主要介绍了Django模板标签{% for %}循环,获取制定条数据实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • 解决pandas read_csv 读取中文列标题文件报错的问题

    解决pandas read_csv 读取中文列标题文件报错的问题

    今天小编就为大家分享一篇解决pandas read_csv 读取中文列标题文件报错的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • Python实现简单多线程任务队列

    Python实现简单多线程任务队列

    本文给大家介绍的是使用很简单的代码实现的多线程任务队列,给大家一个思路,希望对大家学习python能够有所帮助
    2016-02-02
  • python 日期排序的实例代码

    python 日期排序的实例代码

    这篇文章主要介绍了python 日期排序的实例代码,代码简单易懂,非常不错,具有一定的参考借鉴价值 ,需要的朋友可以参考下
    2019-07-07
  • Python+Empyrical实现计算风险指标

    Python+Empyrical实现计算风险指标

    Empyrical 是一个知名的金融风险指标库。它能够用于计算年平均回报、最大回撤、Alpha值等。下面就教你如何使用 Empyrical 这个风险指标计算神器
    2022-05-05

最新评论