在pytorch中对非叶节点的变量计算梯度实例

 更新时间:2020年01月10日 15:39:20   作者:FesianXu  
今天小编就为大家分享一篇在pytorch中对非叶节点的变量计算梯度实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点才需要去更新),这样可以节省很大部分的显存,但是在调试过程中,有时候我们需要对中间变量梯度进行监控,以确保网络的有效性,这个时候我们需要打印出非叶节点的梯度,为了实现这个目的,我们可以通过两种手段进行。

注册hook函数

Tensor.register_hook[2] 可以注册一个反向梯度传导时的hook函数,这个hook函数将会在每次计算 关于该张量 的时候 被调用,经常用于调试的时候打印出非叶节点梯度。当然,通过这个手段,你也可以自定义某一层的梯度更新方法。[3] 具体到这里的打印非叶节点的梯度,代码如:

def hook_y(grad):
 print(grad)

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3

y.register_hook(hook_y) 

out = z.mean()
out.backward()

输出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

retain_grad()

Tensor.retain_grad()显式地保存非叶节点的梯度,当然代价就是会增加显存的消耗,而用hook函数的方法则是在反向计算时直接打印,因此不会增加显存消耗,但是使用起来retain_grad()要比hook函数方便一些。代码如:

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.retain_grad()
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)

输出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

以上这篇在pytorch中对非叶节点的变量计算梯度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • VSCode设置Python语言自动格式化的详细方案

    VSCode设置Python语言自动格式化的详细方案

    VSCode Python自动格式化是指使用VSCode编辑器中的Python插件,可以自动对Python代码进行格式化,使其符合PEP 8规范,这篇文章主要给大家介绍了关于VSCode设置Python语言自动格式化的详细方案,需要的朋友可以参考下
    2023-07-07
  • python用列表生成式写嵌套循环的方法

    python用列表生成式写嵌套循环的方法

    今天小编就为大家分享一篇python用列表生成式写嵌套循环的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-11-11
  • python实现逆波兰计算表达式实例详解

    python实现逆波兰计算表达式实例详解

    这篇文章主要介绍了python实现逆波兰计算表达式的方法,较为详细的分析了逆波兰表达式的概念及实现技巧,具有一定参考借鉴价值,需要的朋友可以参考下
    2015-05-05
  • python单元测试框架unittest基本用法案例

    python单元测试框架unittest基本用法案例

    unittest库unittest库是python的内置库,用来对程序进行测试,下面这篇文章主要给大家介绍了关于python中单元测试框架unittest基本用法的相关资料,需要的朋友可以参考下
    2022-09-09
  • Python详细介绍模型封装部署流程

    Python详细介绍模型封装部署流程

    本文实例讲述了Python模型封装部署的原理与实现方法。封装即是隐藏对象的属性和实现细节,仅对外提供公共访问方式,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2022-08-08
  • Python列表append()函数使用方法详解

    Python列表append()函数使用方法详解

    python中的append()函数是在列表末尾添加新的对象,且将添加的对象最为一个整体,下面这篇文章主要给大家介绍了关于Python列表append()函数使用方法的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-06-06
  • Python相互导入的问题解决

    Python相互导入的问题解决

    大家好,本篇文章主要讲的是Python相互导入的问题解决,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下,方便下次浏览
    2022-01-01
  • Python中的元组介绍

    Python中的元组介绍

    今天小编就为大家分享一篇关于Python中的元组介绍,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2019-01-01
  • python3 将阶乘改成函数形式进行调用的操作

    python3 将阶乘改成函数形式进行调用的操作

    这篇文章主要介绍了python3 将阶乘改成函数形式进行调用的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • 详解Python操作RabbitMQ服务器消息队列的远程结果返回

    详解Python操作RabbitMQ服务器消息队列的远程结果返回

    RabbitMQ是一款基于MQ的服务器,Python可以通过Pika库来进行程序操控,这里我们将来详解Python操作RabbitMQ服务器消息队列的远程结果返回:
    2016-06-06

最新评论