PyTorch实现线性回归详细过程

 更新时间:2022年03月09日 17:13:45   作者:心️升明月  
本文介绍PyTorch实现线性回归,线性关系是一种非常简单的变量之间的关系,因变量和自变量在线性关系的情况下,可以使用线性回归算法对一个或多个因变量和自变量间的线性关系进行建模,该模型的系数可以用最小二乘法进行求解,需要的朋友可以参考一下

一、实现步骤

1、准备数据

x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])

2、设计模型

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
        
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
        
model = LinearModel()  

3、构造损失函数和优化器

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

4、训练过程

epoch_list = []
loss_list = []
w_list = []
b_list = []
for epoch in range(1000):
    y_pred = model(x_data)                      # 计算预测值
    loss = criterion(y_pred, y_data)    # 计算损失
    print(epoch,loss)
    
    epoch_list.append(epoch)
    loss_list.append(loss.data.item())
    w_list.append(model.linear.weight.item())
    b_list.append(model.linear.bias.item())
    
    optimizer.zero_grad()   # 梯度归零
    loss.backward()         # 反向传播
    optimizer.step()        # 更新

5、结果展示

展示最终的权重和偏置:

# 输出权重和偏置
print('w = ',model.linear.weight.item())
print('b = ',model.linear.bias.item())

结果为:

w =  1.9998501539230347
b =  0.0003405189490877092

模型测试:

# 测试模型
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ',y_test.data)

y_pred =  tensor([[7.9997]])

分别绘制损失值随迭代次数变化的二维曲线图和其随权重与偏置变化的三维散点图:

# 二维曲线图
plt.plot(epoch_list,loss_list,'b')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

# 三维散点图
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(w_list,b_list,loss_list,c='r')
#设置坐标轴
ax.set_xlabel('weight')
ax.set_ylabel('bias')
ax.set_zlabel('loss')
plt.show()

结果如下图所示:

 到此这篇关于PyTorch实现线性回归详细过程的文章就介绍到这了,更多相关PyTorch线性回归内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

二、参考文献

相关文章

  • 解决Pyinstaller打包软件失败的一个坑

    解决Pyinstaller打包软件失败的一个坑

    这篇文章主要介绍了解决Pyinstaller打包软件失败的一个坑,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-03-03
  • Python中endswith()函数的基本使用

    Python中endswith()函数的基本使用

    这篇文章主要介绍了Python中endswith()函数的基本使用,是Python学习当中的基础知识,该函数可以用来检测文件类型,需要的朋友可以参考下
    2015-04-04
  • 使用Python实现文本转语音(TTS)并播放音频

    使用Python实现文本转语音(TTS)并播放音频

    在开发涉及语音交互或需要语音提示的应用时,文本转语音(TTS)技术是一个非常实用的工具,下面我们来看看如何使用gTTS和playsound库将文本转换为语音并播放音频文件吧
    2025-03-03
  • python实现防截图的6种方法详解

    python实现防截图的6种方法详解

    防截图是指一组技术或方法,用于防止他人在未经允许的情况下在屏幕上截取或记录图像,这是一个重要的安全措施,它可以防止窃取敏感信息或监视个人信息,本文为大家整理了6种python可以防截图的方法,需要的可以参考下
    2023-10-10
  • python 使用Tensorflow训练BP神经网络实现鸢尾花分类

    python 使用Tensorflow训练BP神经网络实现鸢尾花分类

    这篇文章主要介绍了python 使用Tensorflow训练BP神经网络实现鸢尾花分类,帮助大家更好的利用python进行深度学习,感兴趣的朋友可以了解下
    2021-05-05
  • Python调用OpenCV实现图像平滑代码实例

    Python调用OpenCV实现图像平滑代码实例

    这篇文章主要介绍了Python调用OpenCV实现图像平滑代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-06-06
  • Python中的数学运算操作符使用进阶

    Python中的数学运算操作符使用进阶

    这篇文章主要介绍了Python中的数学运算操作符使用进阶,也包括运算赋值操作符等基本知识的小结,需要的朋友可以参考下
    2016-06-06
  • python list的index()和find()的实现

    python list的index()和find()的实现

    这篇文章主要介绍了python list的index()和find()的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-11-11
  • 一篇文章带你了解Python之Selenium自动化爬虫

    一篇文章带你了解Python之Selenium自动化爬虫

    这篇文章主要为大家详细介绍了Python之Selenium自动化爬虫,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2022-01-01
  • Python实现输入二叉树的先序和中序遍历,再输出后序遍历操作示例

    Python实现输入二叉树的先序和中序遍历,再输出后序遍历操作示例

    这篇文章主要介绍了Python实现输入二叉树的先序和中序遍历,再输出后序遍历操作,涉及Python基于先序遍历和中序遍历构造二叉树,再后序遍历输出相关操作技巧,需要的朋友可以参考下
    2018-07-07

最新评论