pytorch实现梯度下降和反向传播图文详细讲解

 更新时间:2023年04月24日 10:18:10   作者:疯狂的小强呀  
这篇文章主要介绍了pytorch实现梯度下降和反向传播,反向传播的目的是计算成本函数C对网络中任意w或b的偏导数。一旦我们有了这些偏导数,我们将通过一些常数α的乘积和该数量相对于成本函数的偏导数来更新网络中的权重和偏差

反向传播

这里说一下我的理解,反向传播是相对于前向计算的,以公式J(a,b,c)=3(a+bc)为例,前向计算相当于向右计算J(a,b,c)的值,反向传播相当于反过来通过y求变量a,b,c的导数,如下图

手动完成线性回归

import torch
import numpy as np
from matplotlib import pyplot as plt
"""
假设模型为y=w*x+b
我们给出的训练数据是通过y=3*x+1,得到的,其中w=3,b=1
通过训练y=w*x+b观察训练结果是否接近于w=3,b=1
"""
# 设置学习率
learning_rate=0.01
#准备数据
x=torch.rand(500,1) #随机生成500个x作为训练数据
y_true=x*3+1 #根据模型得到x对应的y的实际值
#初始化参数
w=torch.rand([1,1],requires_grad=True) #初始化w
b=torch.rand(1,requires_grad=True,dtype=torch.float32) #初始化b
#通过循环,反向传播,更新参数
for i in range(2000):
    # 通过模型计算y_predict
    y_predict=torch.matmul(x,w)+b #根据模型得到预测值
    #计算loss
    loss=(y_true-y_predict).pow(2).mean()
    #防止梯度累加,每次计算梯度前都将其置为0
    if w.grad is not None:
        w.grad.data.zero_()
    if b.grad is not None:
        b.grad.data.zero_()
    #通过反向传播,记录梯度
    loss.backward()
    #更新参数
    w.data=w.data-learning_rate*w.grad
    b.data=b.data-learning_rate*b.grad
    # 这里打印部分值看一看变化
    if i%50==0:
        print("w,b,loss:",w.item(),b.item(),loss.item())
#设置图像的大小
plt.figure(figsize=(20,8))
#将真实值用散点表示出来
plt.scatter(x.numpy().reshape(-1),y_true.numpy().reshape(-1))
#将预测值用直线表示出来
y_predict=torch.matmul(x,w)+b
plt.plot(x.numpy().reshape(-1),y_predict.detach().numpy().reshape(-1),c="r")
#显示图像
plt.show()

pytorch API完成线性回归

优化器类

优化器(optimizer),可以理解为torch为我们封装的用来进行更新参数的方法,比如常见的随机梯度下降(stochastic gradient descent,SGD)

优化器类都是由torch.optim提供的,例如

  • torch.optim.SGD(参数,学习率)
  • torch.optim.Adam(参数,学习率)

注意:

  • 参数可以使用model.parameters()来获取,获取模型中所有requires_grad=True的参数
  • 优化类的使用方法

①实例化

②所有参数的梯度,将其置为0

③反向传播计算梯度

④更新参数值

实现

import torch
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
import numpy as np
# 1.定义数据,给出x
x=torch.rand(50,1)
# 假定模型为y=w*x+b,根据模型给出真实值y=x*3+0.8
y=x*3+0.8
# print(x)
#2.定义模型
class Lr(torch.nn.Module):
    def __init__(self):
        super(Lr, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        out = self.linear(x)
        return out
# 3.实例化模型、loss、优化器
model=Lr()
criterion=nn.MSELoss()
# print(list(model.parameters()))
optimizer=optim.SGD(model.parameters(),lr=1e-3)
# 4.训练模型
for i in range(30000):
    out=model(x) #获取预测值
    loss=criterion(y,out) #计算损失
    optimizer.zero_grad() #梯度归零
    loss.backward() #计算梯度
    optimizer.step() #更新梯度
    if (i+1)%100 ==0:
        print('Epoch[{}/{}],loss:{:.6f}'.format(i,30000,loss.data))
# 5.模型评估
model.eval() #设置模型为评估模式,即预测模式
predict=model(x)
predict=predict.data.numpy()
plt.scatter(x.data.numpy(),y.data.numpy(),c="r")
plt.plot(x.data.numpy(),predict)
plt.show()

到此这篇关于pytorch实现梯度下降和反向传播图文详细讲解的文章就介绍到这了,更多相关pytorch梯度下降和反向传播内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Django对models里的objects的使用详解

    Django对models里的objects的使用详解

    今天小编就为大家分享一篇Django对models里的objects的使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-08-08
  • python实现逻辑回归的方法示例

    python实现逻辑回归的方法示例

    这篇文章主要介绍了python实现逻辑回归的方法示例,这是机器学习课程的一个实验,整理出来共享给大家,需要的朋友可以参考学习,下来要一起看看吧。
    2017-05-05
  • python如何实现Dice系数

    python如何实现Dice系数

    这篇文章主要介绍了python如何实现Dice系数,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-10-10
  • python psutil库安装教程

    python psutil库安装教程

    这篇文章给大家介绍了python psutil库安装教程,首先要确认本机已安装python环境,具体安装过程大家参考下本文
    2018-03-03
  • springboot aop方式实现接口入参校验的示例代码

    springboot aop方式实现接口入参校验的示例代码

    在实际开发项目中,我们常常需要对接口入参进行校验,本文主要介绍了springboot aop方式实现接口入参校验的示例代码,具有一定的参考价值,感兴趣的可以了解一下
    2023-08-08
  • python输出数学符号实例

    python输出数学符号实例

    这篇文章主要介绍了python输出数学符号实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • Python调用C语言的实现

    Python调用C语言的实现

    这篇文章主要介绍了Python调用C语言的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • Python利用Selenium实现自动观看学习通视频

    Python利用Selenium实现自动观看学习通视频

    Selenium是一个用于Web应用程序测试的工具。Selenium测试直接运行在浏览器中,就像真正的用户在操作一样。本文主要介绍了利用Selenium实现自动观看学习通视频,需要的同学可以参考一下
    2021-12-12
  • Tensorflow之MNIST CNN实现并保存、加载模型

    Tensorflow之MNIST CNN实现并保存、加载模型

    这篇文章主要为大家详细介绍了Tensorflow之MNIST CNN实现并保存、加载模型,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2020-06-06
  • Python+tkinter编写一个最近很火的强制表白神器

    Python+tkinter编写一个最近很火的强制表白神器

    这篇文章主要为大家详细介绍了Python如何通过tkinter编写一个最近很火的强制表白神器,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起尝试一下
    2023-04-04

最新评论