Pytorch backward报错2次访问计算图需要retain_graph=True的情况详解

 更新时间:2024年02月20日 09:47:04   作者:培之  
这篇文章主要介绍了Pytorch backward报错2次访问计算图需要retain_graph=True的情况,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

backward报错2次访问计算图需要 retain_graph=True 的一种情况

错误代码

错误的原因在于

y1 = 0.5*x*2-1.2*x
y2 = x**3

没有放到循环里面,没有随着 x 的优化而相应变化。

import torch
import numpy as np
import torch.optim as optim

torch.autograd.set_detect_anomaly(True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.tensor([1.0, 2.0, 3.0,4.5], dtype=torch.float32, requires_grad=True, device=device)


y_GT= torch.tensor([10, -20, -30,45], dtype=torch.float32,  device=device)

print(f'x{x}')


optimizer = optim.Adam([x], lr=1)
y1 = 0.5*x*2-1.2*x
y2 = x**3

for i in range(10):

    print(f'{i}: x{x}')
    optimizer.zero_grad()


    loss = (y1+y2-y_GT).mean()
    loss.backward()
    optimizer.step()
    print(f'{i}: x{x}')

正确代码

import torch
import numpy as np
import torch.optim as optim

torch.autograd.set_detect_anomaly(True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.tensor([1.0, 2.0, 3.0,4.5], dtype=torch.float32, requires_grad=True, device=device)


y_GT= torch.tensor([10, -20, -30,45], dtype=torch.float32,  device=device)

print(f'x{x}')


optimizer = optim.Adam([x], lr=1)


for i in range(10):

    print(f'{i}: x{x}')
    optimizer.zero_grad()
    y1 = 0.5*x*2-1.2*x
    y2 = x**3

    loss = (y1+y2-y_GT).mean()
    loss.backward()
    optimizer.step()
    print(f'{i}: x{x}')

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • 使用Python处理PDF文件的实践分享

    使用Python处理PDF文件的实践分享

    在现代数字化时代,PDF(Portable Document Format)文件已经成为广泛使用的电子文档格式,这篇文章主要为分享了Python处理PDF文件的简介与实践,需要的可以参考下
    2023-06-06
  • Python中文件路径的处理方式总结

    Python中文件路径的处理方式总结

    本文详细介绍了Python的os和pathlib模块在文件路径处理中的应用,包括常用函数和类方法,以及它们之间的对比和实例演示,旨在帮助开发者提升文件操作效率和代码可读性,需要的朋友可以参考下
    2025-03-03
  • Python基于pyjnius库实现访问java类

    Python基于pyjnius库实现访问java类

    这篇文章主要介绍了Python基于pyjnius库实现访问java类,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-07-07
  • python多进程中的内存复制(实例讲解)

    python多进程中的内存复制(实例讲解)

    下面小编就为大家分享一篇python多进程中的内存复制(实例讲解),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-01-01
  • 使用Python开发提取word所有表格信息的程序

    使用Python开发提取word所有表格信息的程序

    这篇文章主要为大家详细介绍了如何使用Python开发一个可以提取word所有表格信息的程序,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下
    2025-07-07
  • python自动化测试之如何解析excel文件

    python自动化测试之如何解析excel文件

    这篇文章主要介绍了python自动化测试之如何解析excel文件,今天我们就把不同模块处理excel文件的方法做个总结,直接做封装,方便我们以后直接使用,增加工作效率。,需要的朋友可以参考下
    2019-06-06
  • Win10 GPU运算环境搭建(CUDA10.0+Cudnn 7.6.5+pytroch1.2+tensorflow1.14.0)

    Win10 GPU运算环境搭建(CUDA10.0+Cudnn 7.6.5+pytroch1.2+tensorflow1.

    熟悉深度学习的人都知道,深度学习是需要训练的,本文主要介绍了Win10 GPU运算环境搭建,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-09-09
  • 利用Python破解生日悖论问题

    利用Python破解生日悖论问题

    生日悖论,就是23个人在一个房间,期间必然有两个人生日相同的概率为50%,30个人的话概率是70%,60个人甚至上升到99%。本文就来用Python语言破解这一问题,感兴趣的可以了解一下
    2022-12-12
  • python psutil 模块概述及使用示例

    python psutil 模块概述及使用示例

    psutil是一个跨平台的Python库,用于系统监控、性能分析和进程管理,它提供了丰富的API,可用于获取系统的CPU、内存、磁盘、网络等资源的使用情况,以及进行进程管理,psutil支持Linux、Windows、macOS等主流操作系统
    2024-11-11
  • 给Python中的MySQLdb模块添加超时功能的教程

    给Python中的MySQLdb模块添加超时功能的教程

    这篇文章主要介绍了给Python中的MySQLdb模块添加超时功能的教程,timeout功能在服务器的运维当中非常有用,需要的朋友可以参考下
    2015-05-05

最新评论