PyTorch 如何检查模型梯度是否可导

 更新时间:2021年06月05日 11:44:43   作者:烟雨风渡  
这篇文章主要介绍了PyTorch 检查模型梯度是否可导的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

在这里插入图片描述

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

在这里插入图片描述 在这里插入图片描述

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 解决python多行注释引发缩进错误的问题

    解决python多行注释引发缩进错误的问题

    今天小编就为大家分享一篇解决python多行注释引发缩进错误的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • Python语言编写智力问答小游戏功能

    Python语言编写智力问答小游戏功能

    这篇文章主要介绍了使用Python代码语言简单编写一个轻松益智的小游戏,代码简单易懂,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-10-10
  • python中列表元素连接方法join用法实例

    python中列表元素连接方法join用法实例

    这篇文章主要介绍了python中列表元素连接方法join用法,实例分析了Python中join方法的使用技巧,非常具有实用价值,需要的朋友可以参考下
    2015-04-04
  • Python绘图系统之散点图和条形图的实现代码

    Python绘图系统之散点图和条形图的实现代码

    这篇文章主要为大家详细介绍了如何使用Python绘制散点图和条形图,文中的示例代码讲解详细,对我们的学习或工作有一定的帮助,感兴趣的可以了解一下
    2023-08-08
  • 用sqlalchemy构建Django连接池的实例

    用sqlalchemy构建Django连接池的实例

    今天小编就为大家分享一篇用sqlalchemy构建Django连接池的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • NoSql数据库介绍及使用Python连接MongoDB

    NoSql数据库介绍及使用Python连接MongoDB

    MongoDB是一个非常流行的NoSQL数据库,常用于大规模数据存储应用,下面这篇文章主要给大家介绍了关于NoSql数据库及使用Python连接MongoDB的相关资料,需要的朋友可以参考下
    2023-06-06
  • Python从ZabbixAPI获取信息及实现Zabbix-API 监控的方法

    Python从ZabbixAPI获取信息及实现Zabbix-API 监控的方法

    这篇文章主要介绍了Python从ZabbixAPI获取信息及实现Zabbix-API 监控的方法,需要的朋友可以参考下
    2018-09-09
  • Python基于stuck实现scoket文件传输

    Python基于stuck实现scoket文件传输

    这篇文章主要介绍了Python基于stuck实现scoket文件传输,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-04-04
  • python requests指定出口ip的例子

    python requests指定出口ip的例子

    今天小编就为大家分享一篇python requests指定出口ip的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-07-07
  • python 并发下载器实现方法示例

    python 并发下载器实现方法示例

    这篇文章主要介绍了python 并发下载器实现方法,结合实例形式详细分析了并发下载器相关原理及Python并发下载视频的相关操作技巧,需要的朋友可以参考下
    2019-11-11

最新评论