python使用tensorflow保存、加载和使用模型的方法

 更新时间:2018年01月31日 16:22:10   作者:LordofRobots  
本篇文章主要介绍了python使用tensorflow保存、加载和使用模型的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False) 

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0') 

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.] 

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120. 

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run() 

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • Python3.9用pip安装wordcloud库失败的解决过程

    Python3.9用pip安装wordcloud库失败的解决过程

    一般在命令行输入pip install wordcloud 总会显示安装失败,所以下面这篇文章主要给大家介绍了关于Python3.9用pip安装wordcloud库失败的解决过程,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-06-06
  • Python编程中的反模式实例分析

    Python编程中的反模式实例分析

    这篇文章主要介绍了Python编程中的反模式,详细讲述了反模式的害处并以实例形式具体分析了容易造成的易错点,对于Python学习来说具有一定的参考借鉴价值,需要的朋友可以参考下
    2014-12-12
  • django+tornado实现实时查看远程日志的方法

    django+tornado实现实时查看远程日志的方法

    今天小编就为大家分享一篇django+tornado实现实时查看远程日志的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • 探索Python数据可视化库中Plotly Express的使用方法

    探索Python数据可视化库中Plotly Express的使用方法

    在数据分析和可视化领域,数据的有效呈现是至关重要的,python作为一种强大的编程语言,提供了多种数据可视化工具和库,本文将介绍Plotly Express的基本概念和使用方法,帮助读者快速入门并掌握数据可视化的技巧
    2023-06-06
  • python如何控制进程或者线程的个数

    python如何控制进程或者线程的个数

    这篇文章主要介绍了python如何控制进程或者线程的个数,帮助大家更好的理解和使用python,感兴趣的朋友可以了解下
    2020-10-10
  • Scrapy基于selenium结合爬取淘宝的实例讲解

    Scrapy基于selenium结合爬取淘宝的实例讲解

    今天小编就为大家分享一篇Scrapy基于selenium结合爬取淘宝的实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • python获取文件真实链接的方法,针对于302返回码

    python获取文件真实链接的方法,针对于302返回码

    今天小编就为大家分享一篇python获取文件真实链接的方法,针对于302返回码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-05-05
  • 用Python的线程来解决生产者消费问题的示例

    用Python的线程来解决生产者消费问题的示例

    这篇文章主要介绍了用Python的线程来解决生产者消费问题的示例,包括对使用线程中容易出现的一些问题给出了相关解答,需要的朋友可以参考下
    2015-04-04
  • Python控制浏览器自动下载歌词评论并生成词云图

    Python控制浏览器自动下载歌词评论并生成词云图

    本文主要介绍了如何利用Python控制浏览器自动把歌词评论下载下来,并做成好看的词云图。文中的示例代码讲解详细,感兴趣的可以试一试
    2022-01-01
  • 浅谈python浮点数比较的三种方法

    浅谈python浮点数比较的三种方法

    在 Python 中,由于浮点数在计算机内部的表示方式是二进制的,因此进行浮点数比较时可能会出现精度问题,本文就介绍了三种解决方法,具有一定的参考价值,感兴趣的可以了解一下
    2023-09-09

最新评论