用Pytorch实现线性回归模型的步骤

 更新时间:2024年01月16日 11:19:18   作者:chairon  
线性关系是一种非常简单的变量之间的关系,因变量和自变量在线性关系的情况下,可以使用线性回归算法对一个或多个因变量和自变量间的线性关系进行建模,本文主要介绍了如何利用Pytorch实现线性模型,需要的朋友可以参考下

Pytorch实现

步骤

  • 准备数据集
  • 设计模型(计算预测值y_hat):从nn.Module模块继承
  • 构造损失函数和优化器:使用PytorchAPI
  • 训练过程:Forward、Backward、update

1. 准备数据

在PyTorch中计算图是通过mini-batch形式进行,所以X、Y都是多维的Tensor。

在这里插入图片描述

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

2. 设计模型

在之前讲解梯度下降算法时,我们需要自己计算出梯度,然后更新权重。

在这里插入图片描述

而使用Pytorch构造模型,重点时在构建计算图和损失函数上。

在这里插入图片描述

class LinearModel

通过构造一个 class LinearModel类来实现,所有的模型类都需要继承nn.Module,这是所有神经忘了模块的基础类。
class LinearModel这种定义的模型类必须包含两个部分:

  • init():构造函数,进行初始化。
    def __init__(self):
        super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。
        # torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元
        self.linear = torch.nn.Linear(1, 1)

在这里插入图片描述

  • forward():进行前馈计算(backward没有被写,是因为在这种模型类里面会自动实现)

Class nn.Linear 实现了magic method call():它使类的实例可以像函数一样被调用。通常会调用forward()。

    def forward(self, x):
        y_pred = self.linear(x)#调用linear对象,输入x进行预测
        return y_pred

代码

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。
        # torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        y_pred = self.linear(x)#调用linear对象,输入x进行预测
        return y_pred

model = LinearModel()#实例化LinearModel()

3. 构造损失函数和优化器

采用MSE作为损失函数

torch.nn.MSELoss(size_average,reduce)

  • size_average:是否求mini-batch的平均loss。
  • reduce:降维,不用管。

在这里插入图片描述

SGD作为优化器torch.optim.SGD(params, lr):

  • params:参数
  • lr:学习率

在这里插入图片描述

criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

4. 训练过程

  • 预测
  • 计算loss
  • 梯度清零
  • Backward
  • 参数更新
    简化:Forward–>Backward–>更新
#4. Training Cycle
for epoch in range(100):
    y_pred = model(x_data)#Forward:预测
    loss = criterion(y_pred, y_data)#Forward:计算loss
    print(epoch, loss)
    optimizer.zero_grad()#梯度清零
    loss.backward()#backward:计算梯度
    optimizer.step()#通过step()函数进行参数更新

5. 输出和测试

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

完整代码

import torch
#1. Prepare dataset
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

#2. Design Model
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。
        # torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        y_pred = self.linear(x)#调用linear对象,输入x进行预测
        return y_pred

model = LinearModel()#实例化LinearModel()

# 3. Construct Loss and Optimize
criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

#4. Training Cycle
for epoch in range(100):
    y_pred = model(x_data)#Forward:预测
    loss = criterion(y_pred, y_data)#Forward:计算loss
    print(epoch, loss)
    optimizer.zero_grad()#梯度清零
    loss.backward()#backward:计算梯度
    optimizer.step()#通过step()函数进行参数更新

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

输出结果:

85 tensor(0.2294, grad_fn=)
86 tensor(0.2261, grad_fn=)
87 tensor(0.2228, grad_fn=)
88 tensor(0.2196, grad_fn=)
89 tensor(0.2165, grad_fn=)
90 tensor(0.2134, grad_fn=)
91 tensor(0.2103, grad_fn=)
92 tensor(0.2073, grad_fn=)
93 tensor(0.2043, grad_fn=)
94 tensor(0.2014, grad_fn=)
95 tensor(0.1985, grad_fn=)
96 tensor(0.1956, grad_fn=)
97 tensor(0.1928, grad_fn=)
98 tensor(0.1900, grad_fn=)
99 tensor(0.1873, grad_fn=)
w = 1.711882472038269
b = 0.654958963394165
y_pred = tensor([[7.5025]])

可以看到误差还比较大,可以增加训练轮次,训练1000次后的结果:

980 tensor(2.1981e-07, grad_fn=)
981 tensor(2.1671e-07, grad_fn=)
982 tensor(2.1329e-07, grad_fn=)
983 tensor(2.1032e-07, grad_fn=)
984 tensor(2.0737e-07, grad_fn=)
985 tensor(2.0420e-07, grad_fn=)
986 tensor(2.0143e-07, grad_fn=)
987 tensor(1.9854e-07, grad_fn=)
988 tensor(1.9565e-07, grad_fn=)
989 tensor(1.9260e-07, grad_fn=)
990 tensor(1.8995e-07, grad_fn=)
991 tensor(1.8728e-07, grad_fn=)
992 tensor(1.8464e-07, grad_fn=)
993 tensor(1.8188e-07, grad_fn=)
994 tensor(1.7924e-07, grad_fn=)
995 tensor(1.7669e-07, grad_fn=)
996 tensor(1.7435e-07, grad_fn=)
997 tensor(1.7181e-07, grad_fn=)
998 tensor(1.6931e-07, grad_fn=)
999 tensor(1.6700e-07, grad_fn=)
w = 1.9997280836105347
b = 0.0006181497010402381
y_pred = tensor([[7.9995]])

练习

用以下这些优化器替换SGD,得到训练结果并画出损失曲线图。

在这里插入图片描述

比如说:Adam的loss图:

在这里插入图片描述

以上就是用Pytorch实现线性回归模型的步骤的详细内容,更多关于Pytorch线性回归模型的资料请关注脚本之家其它相关文章!

相关文章

  • 使用Python和OpenCV实现实时文档扫描与矫正系统

    使用Python和OpenCV实现实时文档扫描与矫正系统

    在日常工作和学习中,我们经常需要将纸质文档数字化,手动拍摄文档照片常常会出现角度倾斜、透 视变形等问题,影响后续使用,本文将介绍如何使用Python和OpenCV构建一个实时文档扫描与矫正系统,能够通过摄像头自动检测文档边缘并进行透 视变换矫正,需要的朋友可以参考下
    2025-05-05
  • python 爬取百度文库并下载(免费文章限定)

    python 爬取百度文库并下载(免费文章限定)

    这篇文章主要介绍了python 爬取百度文库并下载的示例,帮助大家更好的理解和学习python 爬虫的相关知识,感兴趣的朋友可以了解下
    2020-12-12
  • pandas的连接函数concat()函数的具体使用方法

    pandas的连接函数concat()函数的具体使用方法

    这篇文章主要介绍了pandas的连接函数concat()函数的具体使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2019-07-07
  • python利用xlsxwriter模块 操作 Excel

    python利用xlsxwriter模块 操作 Excel

    这篇文章主要介绍了python利用xlsxwriter模块 操作 Excel,帮助大家更好的利用python处理表格,提高办公效率,感兴趣的朋友可以了解下
    2020-10-10
  • python Tkinter模块使用方法详解

    python Tkinter模块使用方法详解

    Python的GUI库非常多,之所以选择 Tkinter,一是最为简单,二是自带库,不需下载安装,随时使用,跨平台兼容性非常好,下面这篇文章主要给大家介绍了关于python Tkinter模块使用方法的相关资料,需要的朋友可以参考下
    2022-04-04
  • 基于opencv实现手势控制音量(案例详解)

    基于opencv实现手势控制音量(案例详解)

    这篇文章主要介绍了基于opencv的手势控制音量和ai换脸,通过定义了一个名为 handDetector 的类,用于检测和跟踪手部,结合实例代码给大家介绍的非常详细,需要的朋友可以参考下
    2023-08-08
  • Python中@dataclass装饰器实践指南

    Python中@dataclass装饰器实践指南

    这篇文章主要给大家介绍了关于Python@dataclass装饰器的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用Python具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧
    2019-09-09
  • 关于numpy.concatenate()函数的使用及说明

    关于numpy.concatenate()函数的使用及说明

    这篇文章主要介绍了关于numpy.concatenate()函数的使用及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-08-08
  • 深入理解Django自定义信号(signals)

    深入理解Django自定义信号(signals)

    这篇文章主要介绍了深入理解Django自定义信号(signals),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2018-10-10
  • python中*args与**kwarsg及闭包和装饰器的用法

    python中*args与**kwarsg及闭包和装饰器的用法

    这篇文章主要介绍了python中*args与**kwarsg及闭包和装饰器的用法说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
    2022-07-07

最新评论