TensorFlow的权值更新方法

 更新时间:2018年06月14日 09:37:04   作者:朂嘼  
今天小编就为大家分享一篇TensorFlow的权值更新方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

一. MovingAverage权值滑动平均更新

1.1 示例代码:

def create_target_q_network(self,state_dim,action_dim,net):
  state_input = tf.placeholder("float",[None,state_dim])
  action_input = tf.placeholder("float",[None,action_dim])

  ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
  target_update = ema.apply(net)
  target_net = [ema.average(x) for x in net]

  layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
  layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
  q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

  return state_input,action_input,q_value_output,target_update

def update_target(self):
  self.sess.run(self.target_update)
  

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值

tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None) 

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:

cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明

y=tf.stop_gradient(r+Q next r+Qnext)

以上这篇TensorFlow的权值更新方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • videocapture库制作python视频高速传输程序

    videocapture库制作python视频高速传输程序

    python视频高速传输程序,大家参考使用吧
    2013-12-12
  • Django发送html邮件的方法

    Django发送html邮件的方法

    这篇文章主要介绍了Django发送html邮件的方法,涉及Django框架操作邮件的相关技巧,需要的朋友可以参考下
    2015-05-05
  • Pytorch数据拼接与拆分操作实现图解

    Pytorch数据拼接与拆分操作实现图解

    这篇文章主要介绍了Pytorch数据拼接与拆分操作实现图解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-04-04
  • 基于python实现cdn日志文件导入mysql进行分析

    基于python实现cdn日志文件导入mysql进行分析

    这篇文章主要介绍了基于python实现cdn日志文件导入mysql进行分析,本文以阿里云CDN日志作为辅助查询数据展开主题内容,其它云平台大同小异,需要的小伙伴可以参考一下
    2022-05-05
  • Python实现二分查找与bisect模块详解

    Python实现二分查找与bisect模块详解

    二分查找又叫折半查找,二分查找应该属于减治技术的成功应用。python标准库中还有一个灰常给力的模块,那就是bisect。这个库接受有序的序列,内部实现就是二分。下面这篇文章就详细介绍了Python如何实现二分查找与bisect模块,需要的朋友可以参考借鉴,下面来一起看看吧。
    2017-01-01
  • 基于Python实现身份证信息识别功能

    基于Python实现身份证信息识别功能

    身份证是用于证明个人身份和身份信息的官方证件,在现代社会中,身份证被广泛应用于各种场景,如就业、教育、医疗、金融等,它包含了个人的基本信息,本文给大家介绍了如何基于Python实现身份证信息识别功能,感兴趣的朋友可以参考下
    2024-01-01
  • python使用yield压平嵌套字典的超简单方法

    python使用yield压平嵌套字典的超简单方法

    这篇文章主要给大家介绍了关于python使用yield压平嵌套字典的超简单方法,文中通过示例代码介绍的非常详细,对大家的学习或者使用python具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-11-11
  • python绘制堆叠条形图介绍

    python绘制堆叠条形图介绍

    大家好,本篇文章主要讲的是python绘制堆叠条形图介绍,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下,方便下次浏览
    2021-12-12
  • Python学习之名字,作用域,名字空间

    Python学习之名字,作用域,名字空间

    这篇文章主要介绍了Python学习之名字,作用域,名字空间,文章围绕主题展开详细内容介绍,具有一定的参考价值,需要的小伙伴可以参考以一下
    2022-05-05
  • 使用IPython或Spyder将省略号表示的内容完整输出

    使用IPython或Spyder将省略号表示的内容完整输出

    这篇文章主要介绍了使用IPython或Spyder将省略号表示的内容完整输出,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-04-04

最新评论