pytorch的backward()的底层实现逻辑详解

 更新时间:2023年11月20日 11:34:16   作者:子燕若水  
自动微分是一种计算张量(tensors)的梯度(gradients)的技术,它在深度学习中非常有用,这篇文章主要介绍了pytorch的backward()的底层实现逻辑,需要的朋友可以参考下

自动微分是一种计算张量(tensors)的梯度(gradients)的技术,它在深度学习中非常有用。自动微分的基本思想是:

  • 自动微分会记录数据(张量)和所有执行的操作(以及产生的新张量)在一个由函数(Function)对象组成的有向无环图(DAG)中。在这个图中,叶子节点是输入张量,根节点是输出张量。通过从根节点到叶子节点追踪这个图,可以使用链式法则(chain rule)自动地计算梯度。
  • 在前向传播(forward pass)中,自动微分同时做两件事:
    • 运行请求的操作来计算一个结果张量,以及
    • 在 DAG 中保留操作的梯度函数。
    • 在 DAG 中保留操作的梯度函数,这就是说,当你给自动微分一个张量和一个操作,它不仅会计算出结果张量,还会记住这个操作的梯度函数,也就是这个操作对输入张量的导数。例如,如果你给自动微分一个张量 x = [1, 2, 3] 和一个操作 y = x + 1,它不仅会计算出 y = [2, 3, 4],还会记住这个操作的梯度函数是 dy/dx = 1,也就是说,y 对 x 的导数是 1。这样,当你需要计算梯度时,自动微分就可以根据这个梯度函数来计算出结果张量对输入张量的梯度。
  • 在PyTorch中,DAG是动态的。需要注意的一点是,图是从头开始重新创建的;在每个 .backward() 调用之后,autograd开始填充一个新的图。
  • 后向传播开始于当在 DAG 的根节点上调用 .backward() 方法。这个方法会触发自动微分开始计算梯度。
  • 自动微分会从每个 .grad_fn 中计算梯度,这个 .grad_fn 是一个函数对象,它保存了操作的梯度函数。例如,如果一个操作是 y = x + 1,那么它的 .grad_fn 就是 dy/dx = 1。
  • 自动微分会将计算出的梯度累加到相应张量的 .grad 属性中,这个 .grad 属性是一个张量,它保存了结果张量对输入张量的梯度。例如,如果一个结果张量是 y = [2, 3, 4],那么它的 .grad 属性就是 [1, 1, 1],表示 y 对 x 的梯度是 1。
  • 使用链式法则(chain rule),自动微分会一直向后传播,直到到达叶子张量。链式法则是一种数学公式,它可以将复合函数的梯度分解为简单函数的梯度的乘积。例如,如果一个复合函数是 z = f(g(x)),那么它的梯度是 dz/dx = dz/dg * dg/dx。
 
import torch
import torch.nn as nn
M = nn.Linear(2, 2) # neural network module
M.eval() # set M to evaluation mode
with torch.no_grad(): # disable gradient computation
    for param in M.parameters(): # loop over all parameters
        param.fill_(1) # fill the parameter with 1
M.requires_grad_(False)
a = torch.tensor([1., 2.], requires_grad=True) # leaf node
b = torch.tensor([13., 32.], requires_grad=True) # leaf node
c = M(a) # non-leaf node
c2 = M(b) # non-leaf node
d = c * 2  # non-leaf node
d.sum().backward() # compute gradients
print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(M.weight.grad) # None

构建计算图:当我们调用backward()方法时,PyTorch会自动构建从叶子节点a到损失值d.sum()的计算图,这是一个有向无环图,表示了各个张量之间的运算关系。计算图中还包含了两个中间变量c和d,它们是由a经过M模型的前向传播得到的。计算图的作用是记录反向传播的路径,以便于计算梯度。 计算梯度:在计算图中,每个张量都有一个属性grad,用于存储它的梯度值。当我们调用backward()方法时,PyTorch会沿着计算图按照链式法则计算并填充每个张量的grad属性。由于我们只对叶子节点a的梯度感兴趣,所以只有a的grad属性会被计算出来,而中间变量c和d的grad属性会被忽略。a的grad属性的值是损失值d.sum()对a的偏导数,表示了a的变化对损失值的影响。 

到此这篇关于pytorch的backward()的底层实现逻辑的文章就介绍到这了,更多相关pytorch backward()内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • OFD格式文件及如何适应Python将PDF转换为OFD格式文件

    OFD格式文件及如何适应Python将PDF转换为OFD格式文件

    OFD是中国自主研发的一种固定版式文档格式,主要用于电子公文、档案管理等领域,这篇文章主要介绍了OFD格式文件及如何适应Python将PDF转换为OFD格式文件的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下
    2025-11-11
  • Python使用xlrd模块实现操作Excel读写的方法汇总

    Python使用xlrd模块实现操作Excel读写的方法汇总

    本文介绍Python中使用xlrd、xlwt、xlutils模块操作Excel文件的方法,xlrd用于读取Excel文件,但2.0.0版本后不支持.xlsx格式,xlwt用于创建和写入Excel文件,而xlutils主要用于复制和处理Excel文件,详细介绍了如何打开文件、获取工作表信息、操作行列数据和处理日期格式数据
    2024-10-10
  • 详解如何修改jupyter notebook的默认目录和默认浏览器

    详解如何修改jupyter notebook的默认目录和默认浏览器

    这篇文章主要介绍了详解如何修改jupyter notebook的默认目录和默认浏览器,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-01-01
  • python实现图像识别的示例代码

    python实现图像识别的示例代码

    这篇文章主要介绍了python实现图像识别的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2021-03-03
  • Python中selenium获取token的方法

    Python中selenium获取token的方法

    本文主要介绍了Python中selenium获取token的方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • Python使用Matlab命令过程解析

    Python使用Matlab命令过程解析

    这篇文章主要介绍了Python使用Matlab命令过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-06-06
  • 详解Python字符串对象的实现

    详解Python字符串对象的实现

    本文介绍了 python 内部是如何管理字符串对象,以及字符串查找操作是如何实现的,感兴趣的小伙伴们可以参考一下
    2015-12-12
  • 使用Keras实现简单线性回归模型操作

    使用Keras实现简单线性回归模型操作

    这篇文章主要介绍了使用Keras实现简单线性回归模型操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • 用python制作词云视频详解

    用python制作词云视频详解

    这篇文章主要介绍了用python制作词云视频详解,原理解释清晰,代码详细,用于练习很适合,需要的朋友可以参考下
    2021-04-04
  • Pycharm连接远程服务器并远程调试的全过程

    Pycharm连接远程服务器并远程调试的全过程

    PyCharm 是 JetBrains 开发的一款 Python 跨平台编辑器,下面这篇文章主要介绍了Pycharm连接远程服务器并远程调试的全过程,文中通过图文介绍的非常详细,需要的朋友可以参考下
    2021-06-06

最新评论