tensorflow如何继续训练之前保存的模型实例

 更新时间:2020年01月21日 10:06:17   作者:by_side_with_sun  
今天小编就为大家分享一篇tensorflow如何继续训练之前保存的模型实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python生成随机数的方法

    Python生成随机数的方法

    这篇文章主要介绍了Python生成随机数的方法,有需要的朋友可以参考一下
    2014-01-01
  • Python技巧之实现批量统一图片格式和尺寸

    Python技巧之实现批量统一图片格式和尺寸

    大家在工作的时候基本都会接触到很多的图片,有时为了不同的工作需求需要修改图片的尺寸或者大小。本文为大家整理了Python批量转换图片格式和统一图片尺寸,希望对大家有所帮助
    2023-05-05
  • python设计tcp数据包协议类的例子

    python设计tcp数据包协议类的例子

    今天小编就为大家分享一篇python设计tcp数据包协议类的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • Python语言技巧之三元运算符使用介绍

    Python语言技巧之三元运算符使用介绍

    现在大部分高级语言都支持“?”这个三元运算符(ternary operator),它对应的表达式如下:condition ? value if true : value if false。很奇怪的是,这么常用的运算符python居然不支持
    2013-03-03
  • Python处理unicode字符的方法详解

    Python处理unicode字符的方法详解

    这篇文章主要介绍了Python处理unicode字符的方法详解,unicodedata中定义了所有Unicode字符的字符属性,主要包含两个功能,其一是根据名字查找字符;其二是给定字符查找其对应的信息,需要的朋友可以参考下
    2023-08-08
  • Python类属性的延迟计算

    Python类属性的延迟计算

    这篇文章主要为大家详细介绍了Python类属性的延迟计算,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2016-10-10
  • Python2.7基于笛卡尔积算法实现N个数组的排列组合运算示例

    Python2.7基于笛卡尔积算法实现N个数组的排列组合运算示例

    这篇文章主要介绍了Python2.7基于笛卡尔积算法实现N个数组的排列组合运算,涉及Python笛卡尔积算法及排列组合操作相关实现技巧,需要的朋友可以参考下
    2017-11-11
  • Python pomegranate库实现基于贝叶斯网络拼写检查器

    Python pomegranate库实现基于贝叶斯网络拼写检查器

    这篇文章主要为大家介绍了Python pomegranate库实现基于贝叶斯网络拼写检查器示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪<BR>
    2023-04-04
  • Pycharm plot独立窗口显示的操作

    Pycharm plot独立窗口显示的操作

    这篇文章主要介绍了Pycharm plot独立窗口显示的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-12-12
  • python之yield和Generator深入解析

    python之yield和Generator深入解析

    这篇文章主要介绍了python之yield和Generator深入解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-09-09

最新评论