Tensorflow实现部分参数梯度更新操作

 更新时间:2020年01月23日 18:01:25   作者:zchenack  
今天小编就为大家分享一篇Tensorflow实现部分参数梯度更新操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

2. 使用tf.stop_gradient()函数

在建立Graph过程中使用该函数,非常简洁地避免了使用scope获取参数

3. 一个矩阵中部分行或列参数更新

如果一个矩阵,只有部分行或列需要更新参数,其它保持不变,该场景很常见,例如word embedding中,一些预定义的领域相关词保持不变(使用领域相关word embedding初始化),而另一些通用词变化。

import tensorflow as tf
import numpy as np
 
def entry_stop_gradients(target, mask):
  mask_h = tf.abs(mask-1)
  return tf.stop_gradient(mask_h * target) + mask * target
 
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)
 
emb = tf.constant(np.ones([10, 5]))
 
matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
 
parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)
 
loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
  print sess.run(loss)
  print sess.run([grad1, grad2])

以上这篇Tensorflow实现部分参数梯度更新操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 在PyCharm中三步完成PyPy解释器的配置的方法

    在PyCharm中三步完成PyPy解释器的配置的方法

    今天小编就为大家分享一篇在PyCharm中三步完成PyPy解释器的配置的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python提示[Errno 32]Broken pipe导致线程crash错误解决方法

    Python提示[Errno 32]Broken pipe导致线程crash错误解决方法

    这篇文章主要介绍了Python提示[Errno 32]Broken pipe导致线程crash错误解决方法,是ThreadingHTTPServer实现http服务中经常会遇到的问题,需要的朋友可以参考下
    2014-11-11
  • Python利用treap实现双索引的方法

    Python利用treap实现双索引的方法

    所遍历的元素一定是递增(小堆)或是递减(大堆)关系,但是我们无法得知左子树与右子树两部分节点的排序关系。本文就来讲讲算法和数据结构共同满足一组特性,感兴趣的小伙伴请参考下面文章的内容
    2021-09-09
  • python不到50行代码完成了多张excel合并的实现示例

    python不到50行代码完成了多张excel合并的实现示例

    这篇文章主要介绍了python不到50行代码完成了多张excel合并的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-05-05
  • numpy数组拼接简单示例

    numpy数组拼接简单示例

    这篇文章主要介绍了numpy数组拼接简单示例,涉及对numpy数组的介绍,numpy数组的属性等内容,具有一定借鉴价值,需要的朋友可以参考下。
    2017-12-12
  • 使用Python操作ArangoDB的方法步骤

    使用Python操作ArangoDB的方法步骤

    这篇文章主要介绍了使用Python操作ArangoDB的方法步骤,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-02-02
  • Python-Seaborn热图绘制的实现方法

    Python-Seaborn热图绘制的实现方法

    这篇文章主要介绍了Python-Seaborn热图绘制的实现方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • 深入解析Python编程中super关键字的用法

    深入解析Python编程中super关键字的用法

    Python的子类调用父类成员时可以用到super关键字,初始化时需要注意super()和__init__()的区别,下面我们就来深入解析Python编程中super关键字的用法:
    2016-06-06
  • python中的字符串内部换行方法

    python中的字符串内部换行方法

    今天小编就为大家分享一篇python中的字符串内部换行方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-07-07
  • Python3中的bytes和str类型详解

    Python3中的bytes和str类型详解

    这篇文章主要介绍了Python3中的bytes和str类型,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-05-05

最新评论