pytorch如何自定义forward和backward函数

 更新时间:2024年10月12日 16:08:13   作者:xx_xjm  
PyTorch自动求导功能强大,但在特定情况下需要用户自行定义backward函数,通过实例解释了保存变量、计算梯度、链式法则等核心概念,并展示了如何通过自定义函数集成到网络中以及如何正确返回梯度,此外,还讨论了多输出情况下的梯度传递

pytorch自定义forward和backward函数

pytorch会自动求导,但是当遇到无法自动求导的时候,需要自己认为定义求导过程,这个时候就涉及到要定义自己的forward和backward函数。

举例如下:

看到这里,大家应该会有很多疑问

比如:

  • 1:ctx.save_for_backward和ctx.saved_tensors的含义
  • 2:backward中各个计算函数的意义,以及backward的输入参数grad_out是什么,以及grad_out包含哪些数据。

针对以上问题,我们一个个解答

  • 第一个问题:百度吧,答案很多!!!!
  • 第二个问题:拿上面这个例子来看,我们定义了一个类似于线性层的东西,但注意这不是线性层,因为我们是直接把输入和weight用*来做点对点的乘法的,所以这不是我们通常情况下的线性层。

但是这么看也费劲,我们写一个网络,把这个函数加到网络中去,再完整的跑一遍看吧!

测试代码

结果如下:

来进行解答

首先,backward函数的返回值,就是对应着forward里面的参数的梯度,也就是说,forward函数里面有几个输入参数,那么backward函数的输出就要有几个!为什么是这样?

我们首先要理解backward的输入grad_out,为什么backward的参数就是一个,因为这是根据链式法则来的

比如,我们定义三个函数H(对应上面网络中linear1),F(自定义函数xjm_inter),D(对应上面网络中linear2),定义一个输入x(对应上面输入a),定义一个输出y(对应上面输出b):

y = D(F(H(X)))

现在,我们求y对x的偏导,那么:

dy/dx = dy/dD * dD/dF * dF/dH * dH/dx

好吧看到这里你可能还是不懂,为什么backward的参数就是一个grad_out!!

我们韩式以上面则个函数为例子,但是,我们现在不求y对x的导数,我们假设F函数有一个叶子节点(或者说requires_grad=True)的参数w1,现在我们要求y对w1的导数:

所以

dy/dw1 = dy/dD *dD/dF * dF/dw1

那么此时,F就是我们上面代码中自定义的xjm_inter函数,则 grad_out = dy/dD *dD/dF。

怎么理解呢,根据链式法则,我们呢所定义的网络中的每一层都是一个单独的函数,所以函数中的变量的最终求导其实只取决于该函数本身,链式法则求导传递过来的其实永远都知识一个值,这就是为什么backward函数的输出只有一个。

扩展

当forward的输出有多个的时候,那么就有多个链式法则,因为可以同时对x或者对w求导,此时backward的输入可以是一个,也可以是对应forward输出的个数,如果是一个则是一个元组,包含对应的梯度!!!

那么我们的backward要实现什么样的功能呢?说到这里,大家应该大概能明白了,就是实现当前层那的梯度计算,并进行返回,所以,这也是为什么backward的返回值要和forward的输入值一一对应,否则会报错。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Python中有哪些关键字及关键字的用法

    Python中有哪些关键字及关键字的用法

    这篇文章主要介绍了Python中有哪些关键字及关键字的用法,分享python中常用的关键字,本文结合示例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-02-02
  • Python实现的FTP通信客户端与服务器端功能示例

    Python实现的FTP通信客户端与服务器端功能示例

    这篇文章主要介绍了Python实现的FTP通信客户端与服务器端功能,涉及Python基于socket的端口监听、文件传输等相关操作技巧,需要的朋友可以参考下
    2018-03-03
  • 基于Python实现简易学生信息管理系统

    基于Python实现简易学生信息管理系统

    这篇文章主要为大家详细介绍了python实现简易学生信息管理系统,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-07-07
  • Python 模拟动态产生字母验证码图片功能

    Python 模拟动态产生字母验证码图片功能

    这篇文章主要介绍了Python 模拟动态产生字母验证码图片,这里给大家介绍了pillow模块的使用,需要的朋友可以参考下
    2019-12-12
  • Python3如何日志同时输出到控制台和文件

    Python3如何日志同时输出到控制台和文件

    这篇文章主要介绍了Python3如何日志同时输出到控制台和文件问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-11-11
  • pytho matplotlib工具栏源码探析一之禁用工具栏、默认工具栏和工具栏管理器三种模式的差异

    pytho matplotlib工具栏源码探析一之禁用工具栏、默认工具栏和工具栏管理器三种模式的差异

    这篇文章主要介绍了pytho matplotlib工具栏源码探析一之禁用工具栏、默认工具栏和工具栏管理器三种模式的差异,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2021-02-02
  • Python编码类型转换方法详解

    Python编码类型转换方法详解

    这篇文章主要介绍了Python编码类型转换方法,结合实例形式详细分析了Python针对各种常见编码的转码与解码等操作技巧,需要的朋友可以参考下
    2016-07-07
  • python多进程下的生产者和消费者模型

    python多进程下的生产者和消费者模型

    这篇文章主要介绍了python多进程下的生产者和消费者模型,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-05-05
  • Python编译过程和执行原理解析

    Python编译过程和执行原理解析

    这篇文章主要介绍了Python编译过程和执行原理解析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2023-07-07
  • Python NumPy教程之数组的基本操作详解

    Python NumPy教程之数组的基本操作详解

    Numpy 中的数组是一个元素表(通常是数字),所有元素类型相同,由正整数元组索引。本文将通过一些示例详细讲一下NumPy中数组的一些基本操作,需要的可以参考一下
    2022-08-08

最新评论