pytorch 如何打印网络回传梯度

 更新时间:2021年05月13日 11:52:22   作者:Jee_King  
这篇文章主要介绍了pytorch 实现打印网络回传梯度的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

需求:

打印梯度,检查网络学习情况

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)
 ...

打印结果如下:

name表示网络参数的名字; parms.requires_grad 表示该参数是否可学习,是不是frozen的; parm.grad 打印该参数的梯度值。

补充:pytorch的梯度计算

看代码吧~

import torch
from torch.autograd import Variable
x = torch.Tensor([[1.,2.,3.],[4.,5.,6.]])  #grad_fn是None
x = Variable(x, requires_grad=True)
y = x + 2
z = y*y*3
out = z.mean()
#x->y->z->out
print(x)
print(y)
print(z)
print(out)
#结果:
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[3., 4., 5.],
        [6., 7., 8.]], grad_fn=<AddBackward>)
tensor([[ 27.,  48.,  75.],
        [108., 147., 192.]], grad_fn=<MulBackward>)
tensor(99.5000, grad_fn=<MeanBackward1>)

若是关于graph leaves求导的结果变量是一个标量,那么gradient默认为None,或者指定为“torch.Tensor([1.0])”

若是关于graph leaves求导的结果变量是一个向量,那么gradient是不能缺省的,要是和该向量同纬度的tensor

out.backward()
print(x.grad)
#结果:
tensor([[3., 4., 5.],
        [6., 7., 8.]])
#如果是z关于x求导就必须指定gradient参数:
gradients = torch.Tensor([[2.,1.,1.],[1.,1.,1.]])
z.backward(gradient=gradients)
#若z不是一个标量,那么就先构造一个标量的值:L = torch.sum(z*gradient),再关于L对各个leaf Variable计算梯度
#对x关于L求梯度
x.grad
#结果:
tensor([[36., 24., 30.],
        [36., 42., 48.]])

错误情况

z.backward()
print(x.grad) 
#报错:RuntimeError: grad can be implicitly created only for scalar outputs只能为标量创建隐式变量
    
x1 = Variable(torch.Tensor([[1.,2.,3.],[4.,5.,6.]])) 
x2 = Variable(torch.arange(4).view(2,2).type(torch.float), requires_grad=True)
c = x2.mm(x1)
c.backward(torch.ones_like(c))
# c.backward()
#RuntimeError: grad can be implicitly created only for scalar outputs
print(x2.grad)

从上面的例子中,out是常量,可以默认创建隐变量,如果反向传播的不是常量,要知道该矩阵的具体值,在网络中就是loss矩阵,方向传播的过程中就是拿该归一化的损失乘梯度来更新各神经元的参数。

看到一个博客这样说:loss = criterion(outputs, labels)对应loss += (label[k] - h) * (label[k] - h) / 2

就是求loss(其实我觉得这一步不用也可以,反向传播时用不到loss值,只是为了让我们知道当前的loss是多少)

我认为一定是要求loss的具体值,才能对比阈值进行分类,通过非线性激活函数,判断是否激活。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 趣味Python实战练习之自动更换桌面壁纸脚本附源码

    趣味Python实战练习之自动更换桌面壁纸脚本附源码

    读万卷书不如行万里路,学的扎不扎实要通过实战才能看出来,本篇文章手把手带你编写一个自动更换桌面壁纸的脚本,代码简洁而且短,相信你一定看得懂,大家可以在过程中查缺补漏,看看自己掌握程度怎么样
    2021-10-10
  • Python语法糖遍历列表时删除元素方法示例详解

    Python语法糖遍历列表时删除元素方法示例详解

    这篇文章主要为大家介绍了Python语法糖遍历列表时删除元素详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
    2023-05-05
  • python删除指定类型(或非指定)的文件实例详解

    python删除指定类型(或非指定)的文件实例详解

    这篇文章主要介绍了python删除指定类型(或非指定)的文件,以实例形式较为详细的分析了Python删除文件的相关技巧,需要的朋友可以参考下
    2015-07-07
  • Python之tkinter进度条Progressbar用法解读

    Python之tkinter进度条Progressbar用法解读

    这篇文章主要介绍了Python之tkinter进度条Progressbar用法解读,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2023-05-05
  • python如何进行基准测试

    python如何进行基准测试

    这篇文章主要介绍了python如何进行基准测试,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下
    2021-04-04
  • python读取tif图片时保留其16bit的编码格式实例

    python读取tif图片时保留其16bit的编码格式实例

    今天小编就为大家分享一篇python读取tif图片时保留其16bit的编码格式实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • 基于Pycharm加载多个项目过程图解

    基于Pycharm加载多个项目过程图解

    这篇文章主要介绍了基于Pycharm加载多个项目过程图解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-01-01
  • 以一个投票程序的实例来讲解Python的Django框架使用

    以一个投票程序的实例来讲解Python的Django框架使用

    这篇文章主要介绍了以一个投票程序的实例来讲解Python的Django框架使用,Django是Python世界中人气最高的MVC框架,需要的朋友可以参考下
    2016-02-02
  • Python中的pprint折腾记

    Python中的pprint折腾记

    这篇文章主要介绍了Python中的pprint折腾记,本文着重讲解pprint的使用,并给出使用实例,需要的朋友可以参考下
    2015-01-01
  • 利用Python将社交网络进行可视化

    利用Python将社交网络进行可视化

    这篇文章介绍了利用Python将社交网络进行可视化,主要是一些Python的第三方库来进行社交网络的可视化,利用领英(Linkedin)的社交关系数据展开介绍,内容可当学习练习题有一定的参考价值,需要的小伙伴可以参考一下
    2022-06-06

最新评论