tensorflow 加载部分变量的实例讲解

 更新时间:2018年07月27日 11:42:47   作者:imperfect00  
今天小编就为大家分享一篇tensorflow 加载部分变量的实例讲解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 init_op = tf.global_variables_initializer()
 sess.run(init_op)
 saver.save(sess,"checkpoint/model_test",global_step=1)

当我们保存模型后,我们可以通过saver.restore()来加载模型,初始化变量:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

神经网络训练时,有时候我们需要从预训练的模型中加载部分参数,初始化当前模型,例如加入CNN有6层,我们需要从已有的模型初始化CNN前5层参数.这可以通过saver.restore()实现.

之前我们已经介绍可以通过tf.train.Saver()的保存部分变量的方法,即需要保存的变量列表,同样的,在变量初始化的时候,我们可以对需要单独初始化的变量分别定义一个tf.train.Saver()函数,这样就可以单独对该部分变量初始化,例如下面代码,saver1用于初始化变量v1,saver2用于初始化变量v2,v3:

import tensorflow as tf
 
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
#saver = tf.train.Saver()
saver1 = tf.train.Saver([v1])
saver2 = tf.train.Saver([v2]+[v3])
with tf.Session() as sess:
 # init_op = tf.global_variables_initializer()
 # sess.run(init_op)
 saver1.restore(sess, "checkpoint/model_test-1")
 saver2.restore(sess, "checkpoint/model_test-1")
 # saver.save(sess,"checkpoint/model_test",global_step=1)

以上这篇tensorflow 加载部分变量的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • HTML的form表单和django的form表单

    HTML的form表单和django的form表单

    这篇文章主要介绍了HTML的form表单和django的form表单,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • Python聚类算法之DBSACN实例分析

    Python聚类算法之DBSACN实例分析

    这篇文章主要介绍了Python聚类算法之DBSACN,结合实例形式详细分析了DBSACN算法的原理与具体实现技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-11-11
  • python机器学习sklearn实现识别数字

    python机器学习sklearn实现识别数字

    本文主要介绍了python机器学习sklearn实现识别数字,主要简述如何通过sklearn模块来进行预测和学习,最后再以图表这种更加直观的方式展现出来,感兴趣的可以了解一下
    2022-03-03
  • python里使用正则的findall函数的实例详解

    python里使用正则的findall函数的实例详解

    这篇文章主要介绍了python里使用正则的findall函数的实例详解的相关资料,希望通过本文能帮助到大家,需要的朋友可以参考下
    2017-10-10
  • Pyinstaller 打包exe教程及问题解决

    Pyinstaller 打包exe教程及问题解决

    这篇文章主要介绍了Pyinstaller 打包exe教程及问题解决,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-08-08
  • 如何使Python中的print()语句运行结果不换行

    如何使Python中的print()语句运行结果不换行

    这篇文章主要介绍了如何使Python中的print()显示当前语句后不换行,print() 是一个常用函数,但是每次,print()语句显示后都会换行,本问我们就来节日如何使print()显示当前语句后不换行,需要的朋友可以参考一下
    2022-03-03
  • django Admin文档生成器使用详解

    django Admin文档生成器使用详解

    这篇文章主要介绍了django Admin文档生成器,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-07-07
  • pandas DataFrame数据转为list的方法

    pandas DataFrame数据转为list的方法

    下面小编就为大家分享一篇pandas DataFrame数据转为list的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-04-04
  • 如何利用Python+OpenCV实现简易图像边缘轮廓检测(零基础)

    如何利用Python+OpenCV实现简易图像边缘轮廓检测(零基础)

    轮廓是形状分析和物体检测和识别的有用工具,下面这篇文章主要给大家介绍了关于如何利用Python+OpenCV实现简易图像边缘轮廓检测(零基础)的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-05-05
  • 使用Python项目生成所有依赖包的清单方式

    使用Python项目生成所有依赖包的清单方式

    这篇文章主要介绍了使用Python项目生成所有依赖包的清单方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-07-07

最新评论