Python反向传播实现线性回归步骤详细讲解

 更新时间:2022年10月18日 09:40:49   作者:Henry_zs  
回归是监督学习的一个重要问题,回归用于预测输入变量和输出变量之间的关系,特别是当输入变量的值发生变化时,输出变量的值也随之发生变化。回归模型正是表示从输入变量到输出变量之间映射的函数

1. 导入包

我们这次的任务是随机生成一些离散的点,然后用直线(y = w *x + b )去拟合

首先看一下我们需要导入的包有

torch 包为我们生成张量,可以使用反向传播

matplotlib.pyplot 包帮助我们绘制曲线,实现可视化

2. 生成数据

这里我们通过rand随机生成数据,因为生成的数据在0~1之间,这里我们扩大10倍。

我们设置的batch_size,也就是数据的个数为20个,所以这里会产生维度是(20,1)个训练样本

我们假设大概的回归是 y = 2 * x + 3 的,为了保证损失不一直为0 ,这里我们添加一点噪音

最后返回x作为输入,y作为真实值label

rand [0,1]均匀分布

如果想要每次产生的随机数是一样的,可以在代码的前面设置一下随机数种子

3. 训练数据

首先,我们要建立的模型是线性的y = w * x + b ,所以我们需要先初始化w ,b

使用randn 标准正态分布随机初始化权重w,将偏置b初始化为0

为什么将权重w随机初始化?

  • 首先,为了抑制过拟合,提高模型的泛化能力,我们可以采用权重衰减来抑制权重w的大小。因为权重过大,对应的输入x的特征就越重要,但是如果对应x是噪音的话,那么系统就会陷入过拟合中。所以我们希望得到的模型曲线是一条光滑的,对输入不敏感的曲线,所以w越小越好
  • 那这样为什么不直接把权重初始化为0,或者说很小很小的数字呢。因为,w太小的话,那么在反向传播的时候,由于我们习惯学习率lr 设置很小,那在更新w的时候基本就不更新了。而不把权重设置为0,是因为无论训练多久,在更新权重的时候,所有权重都会被更新成相同的值,这样多层隐藏层就没有意义了。严格来说,是为了瓦解权重的对称结构

接下来可以训练我们的模型了

1. 将输入的特征x和对应真实值label y通过zip函数打包。将输入x经过模型 w *x + b 的预测输出预测值y

2. 计算损失函数loss,因为之前将w、b都是设置成会计算梯度的,那么loss.backward() 会自动计算w和b的梯度。用w的值data,减去梯度的值grad.data 乘上 学习率lr完成一次更新

3. 当w、b梯度不为零的话,要清零。这里有两种解释,第一种是每次计算完梯度后,值会和之前计算的梯度值进行累加,而我们只是需要当前这步的梯度值,所有我们需要将之前的值清零。第二种是,因为梯度的累加,那么相当于实现一个很大的batch训练。假如一个epoch里面,梯度不进行清零的话,相当于把所有的样本求和后在进行梯度下降,而不是我们原先使用的针对单个样本进行下降的SGD算法

4. 每100次迭代后,我们打印一下损失

4. 绘制图像

scatter 相当于离散点的绘图

要绘制连续的图像,只需要给个定义域然后通过表达式 w * x +b 计算y就可以了,最后输出一下w和b,看看是不是和我们设置的w = 2,b =3 接近

5. 代码

import torch
import matplotlib.pyplot as plt
def trainSet(batch_size = 20):   # 定义训练集
    x = torch.rand(batch_size,1) * 10
    y = x * 2 + 3 + torch.randn(batch_size,1)   # y = x * 2  + 3(近似)
    return x,y
train_x, train_y = trainSet()   # 训练集
w =torch.randn(1,requires_grad= True)
b = torch.zeros(1,requires_grad= True)
lr = 0.001
for epoch in range(1000):
    for x,y in zip(train_x,train_y):  # SGD算法,如果是BSGD的话,不需要这个for
        y_pred = w*x  + b
        loss = (y - y_pred).pow(2) / 2
        loss.backward()
        w.data -= w.grad.data * lr
        b.data -= b.grad.data * lr
        if w.data is not True:   # 梯度值不为零的话,要清零
            w.grad.data.zero_()   #  否则相当于一个大的batch训练
        if b.data is not True:
            b.grad.data.zero_()
    if epoch % 100 ==0:
        print('loss:',loss.data)
plt.scatter(train_x,train_y)
x = torch.arange(0,11).view(-1,1)
y = x * w.data + b.data
plt.plot(x,y)
plt.show()
print(w.data,b.data)

输出的图像

输出的结果为

这里可以看的最后的w = 1.9865和b = 2.9857 和我们设置的2,3是接近的

到此这篇关于Python反向传播实现线性回归步骤详细讲解的文章就介绍到这了,更多相关Python线性回归内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • 在Python中f-string的几个技巧,你都知道吗

    在Python中f-string的几个技巧,你都知道吗

    f-string想必很多Python用户都基础性的使用过,但是百分之九十的人不知道?在Python中f-string的几个技巧,今天就带大家一起看看Python f-string技巧大全,需要的朋友参考下吧
    2021-10-10
  • Python实现在Excel文件中写入图表

    Python实现在Excel文件中写入图表

    这篇文章主要为大家介绍了如何利用Python语言实现在Excel文件中写入一个比较简单的图表,文中的实现方法讲解详细,快动手尝试一下吧
    2022-05-05
  • Python处理文本文件中控制字符的方法

    Python处理文本文件中控制字符的方法

    最近在使用Python的时候遇到过文档中出现控制字符报错的问题。想着总结一下,方便以后需要或这同样遇到问题的朋友,下面这篇文章主要介绍了Python处理文本文件中控制字符的解决方法,需要的朋友可以参考借鉴。
    2017-02-02
  • python实现过滤敏感词

    python实现过滤敏感词

    这篇文章主要介绍了python如何实现过滤敏感词,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下
    2021-05-05
  • 使用python进行拆分大文件的方法

    使用python进行拆分大文件的方法

    今天小编就为大家分享一篇使用python进行拆分大文件的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • python中os库的具体使用

    python中os库的具体使用

    本文介绍了Python中os库的一些常见功能,包括获取和改变工作目录、列出目录内容、创建和删除目录等,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2024-11-11
  • Python selenium爬取微博数据代码实例

    Python selenium爬取微博数据代码实例

    这篇文章主要介绍了Python selenium爬取微博数据代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-05-05
  • Pytest单元测试框架生成HTML测试报告及优化的步骤

    Pytest单元测试框架生成HTML测试报告及优化的步骤

    本文主要介绍了Pytest单元测试框架生成HTML测试报告及优化的步骤,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2022-01-01
  • python持久性管理pickle模块详细介绍

    python持久性管理pickle模块详细介绍

    这篇文章主要介绍了python持久性管理pickle模块详细介绍,本文讲解了什么是持久性、一些经过 pickle 的 Python等内容,并讲给出了18个使用示例,需要的朋友可以参考下
    2015-02-02
  • Matplotlib控制坐标轴刻度间距与标签实例代码

    Matplotlib控制坐标轴刻度间距与标签实例代码

    在matplotlib中,记号是图形两个轴上的小标记,到目前为止,我们让matplotlib处理轴图例上记号的位置,下面这篇文章主要给大家介绍了关于Matplotlib控制坐标轴刻度间距与标签的相关资料,需要的朋友可以参考下
    2021-10-10

最新评论